diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..08891d83f --- /dev/null +++ b/.editorconfig @@ -0,0 +1,8 @@ +root = true + +[*] +indent_style = space +indent_size = 4 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true \ No newline at end of file diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 000000000..e29eb8464 --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,14 @@ +changelog: + categories: + - title: SemVer Major + labels: + - ⚠️ semver/major + - title: SemVer Minor + labels: + - 🆕 semver/minor + - title: SemVer Patch + labels: + - 🔨 semver/patch + - title: Other Changes + labels: + - semver/none diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 000000000..d2fdc3809 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,31 @@ +name: Main + +permissions: + contents: read + +on: + push: + branches: [main] + schedule: + - cron: "0 8,20 * * *" + +jobs: + unit-tests: + name: Unit tests + uses: apple/swift-nio/.github/workflows/unit_tests.yml@main + with: + linux_5_10_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_1_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_2_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_nightly_next_arguments_override: "--explicit-target-dependency-import-check error" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error" + + static-sdk: + name: Static SDK + # Workaround https://github.com/nektos/act/issues/1875 + uses: apple/swift-nio/.github/workflows/static_sdk.yml@main + + release-builds: + name: Release builds + uses: apple/swift-nio/.github/workflows/release_builds.yml@main diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml new file mode 100644 index 000000000..7423cd3b2 --- /dev/null +++ b/.github/workflows/pull_request.yml @@ -0,0 +1,40 @@ +name: PR + +permissions: + contents: read + +on: + pull_request: + types: [opened, reopened, synchronize] + +jobs: + soundness: + name: Soundness + uses: swiftlang/github-workflows/.github/workflows/soundness.yml@main + with: + license_header_check_project_name: "AsyncHTTPClient" + unit-tests: + name: Unit tests + uses: apple/swift-nio/.github/workflows/unit_tests.yml@main + with: + linux_5_10_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_1_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_2_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_nightly_next_arguments_override: "--explicit-target-dependency-import-check error" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error" + + cxx-interop: + name: Cxx interop + uses: apple/swift-nio/.github/workflows/cxx_interop.yml@main + with: + linux_5_9_enabled: false + + static-sdk: + name: Static SDK + # Workaround https://github.com/nektos/act/issues/1875 + uses: apple/swift-nio/.github/workflows/static_sdk.yml@main + + release-builds: + name: Release builds + uses: apple/swift-nio/.github/workflows/release_builds.yml@main diff --git a/.github/workflows/pull_request_label.yml b/.github/workflows/pull_request_label.yml new file mode 100644 index 000000000..d2da2f1ac --- /dev/null +++ b/.github/workflows/pull_request_label.yml @@ -0,0 +1,21 @@ +name: PR label + +permissions: + contents: read + +on: + pull_request: + types: [labeled, unlabeled, opened, reopened, synchronize] + +jobs: + semver-label-check: + name: Semantic version label check + runs-on: ubuntu-latest + timeout-minutes: 1 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Check for Semantic Version label + uses: apple/swift-nio/.github/actions/pull_request_semver_label_checker@main diff --git a/.licenseignore b/.licenseignore new file mode 100644 index 000000000..151ce9245 --- /dev/null +++ b/.licenseignore @@ -0,0 +1,39 @@ +.gitignore +**/.gitignore +.licenseignore +.gitattributes +.git-blame-ignore-revs +.mailfilter +.mailmap +.spi.yml +.swift-format +.editorconfig +.github/* +*.md +*.txt +*.yml +*.yaml +*.json +Package.swift +**/Package.swift +Package@swift-*.swift +**/Package@swift-*.swift +Package@-*.swift +**/Package@-*.swift +Package.resolved +**/Package.resolved +Makefile +*.modulemap +**/*.modulemap +**/*.docc/* +*.xcprivacy +**/*.xcprivacy +*.symlink +**/*.symlink +Dockerfile +**/Dockerfile +.dockerignore +Snippets/* +dev/git.commit.template +.unacceptablelanguageignore +Tests/AsyncHTTPClientTests/Resources/*.pem diff --git a/.spi.yml b/.spi.yml new file mode 100644 index 000000000..795484b35 --- /dev/null +++ b/.spi.yml @@ -0,0 +1,4 @@ +version: 1 +builder: + configs: + - documentation_targets: [AsyncHTTPClient] diff --git a/.swift-format b/.swift-format new file mode 100644 index 000000000..7e8ae7391 --- /dev/null +++ b/.swift-format @@ -0,0 +1,68 @@ +{ + "version" : 1, + "indentation" : { + "spaces" : 4 + }, + "tabWidth" : 4, + "fileScopedDeclarationPrivacy" : { + "accessLevel" : "private" + }, + "spacesAroundRangeFormationOperators" : false, + "indentConditionalCompilationBlocks" : false, + "indentSwitchCaseLabels" : false, + "lineBreakAroundMultilineExpressionChainComponents" : false, + "lineBreakBeforeControlFlowKeywords" : false, + "lineBreakBeforeEachArgument" : true, + "lineBreakBeforeEachGenericRequirement" : true, + "lineLength" : 120, + "maximumBlankLines" : 1, + "respectsExistingLineBreaks" : true, + "prioritizeKeepingFunctionOutputTogether" : true, + "noAssignmentInExpressions" : { + "allowedFunctions" : [ + "XCTAssertNoThrow", + "XCTAssertThrowsError" + ] + }, + "rules" : { + "AllPublicDeclarationsHaveDocumentation" : false, + "AlwaysUseLiteralForEmptyCollectionInit" : false, + "AlwaysUseLowerCamelCase" : false, + "AmbiguousTrailingClosureOverload" : true, + "BeginDocumentationCommentWithOneLineSummary" : false, + "DoNotUseSemicolons" : true, + "DontRepeatTypeInStaticProperties" : true, + "FileScopedDeclarationPrivacy" : true, + "FullyIndirectEnum" : true, + "GroupNumericLiterals" : true, + "IdentifiersMustBeASCII" : true, + "NeverForceUnwrap" : false, + "NeverUseForceTry" : false, + "NeverUseImplicitlyUnwrappedOptionals" : false, + "NoAccessLevelOnExtensionDeclaration" : true, + "NoAssignmentInExpressions" : true, + "NoBlockComments" : true, + "NoCasesWithOnlyFallthrough" : true, + "NoEmptyTrailingClosureParentheses" : true, + "NoLabelsInCasePatterns" : true, + "NoLeadingUnderscores" : false, + "NoParensAroundConditions" : true, + "NoVoidReturnOnFunctionSignature" : true, + "OmitExplicitReturns" : true, + "OneCasePerLine" : true, + "OneVariableDeclarationPerLine" : true, + "OnlyOneTrailingClosureArgument" : true, + "OrderedImports" : true, + "ReplaceForEachWithForLoop" : true, + "ReturnVoidInsteadOfEmptyTuple" : true, + "UseEarlyExits" : false, + "UseExplicitNilCheckInConditions" : false, + "UseLetInEveryBoundCaseVariable" : false, + "UseShorthandTypeNames" : true, + "UseSingleLinePropertyGetter" : false, + "UseSynthesizedInitializer" : false, + "UseTripleSlashForDocumentationComments" : true, + "UseWhereClausesInForLoops" : false, + "ValidateDocumentationComments" : false + } +} diff --git a/.swiftformat b/.swiftformat deleted file mode 100644 index c26e226e3..000000000 --- a/.swiftformat +++ /dev/null @@ -1,23 +0,0 @@ -# file options - ---swiftversion 5.4 ---exclude .build - -# format options - ---self insert ---patternlet inline ---ranges nospace ---stripunusedargs unnamed-only ---ifdef no-indent ---extensionacl on-declarations ---disable typeSugar # https://github.com/nicklockwood/SwiftFormat/issues/636 ---disable andOperator ---disable wrapMultilineStatementBraces ---disable enumNamespaces ---disable redundantExtensionACL ---disable redundantReturn ---disable preferKeyPath ---disable sortedSwitchCases - -# rules diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index c7a248828..76501d7d6 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,55 +1,5 @@ # Code of Conduct -To be a truly great community, AsyncHTTPClient needs to welcome developers from all walks of life, -with different backgrounds, and with a wide range of experience. A diverse and friendly -community will have more great ideas, more unique perspectives, and produce more great -code. We will work diligently to make the AsyncHTTPClient community welcoming to everyone. -To give clarity of what is expected of our members, AsyncHTTPClient has adopted the code of conduct -defined by [contributor-covenant.org](https://www.contributor-covenant.org). This document is used across many open source -communities, and we think it articulates our values well. The full text is copied below: +The code of conduct for this project can be found at https://swift.org/code-of-conduct. -### Contributor Code of Conduct v1.3 -As contributors and maintainers of this project, and in the interest of fostering an open and -welcoming community, we pledge to respect all people who contribute through reporting -issues, posting feature requests, updating documentation, submitting pull requests or patches, -and other activities. - -We are committed to making participation in this project a harassment-free experience for -everyone, regardless of level of experience, gender, gender identity and expression, sexual -orientation, disability, personal appearance, body size, race, ethnicity, age, religion, or -nationality. - -Examples of unacceptable behavior by participants include: -- The use of sexualized language or imagery -- Personal attacks -- Trolling or insulting/derogatory comments -- Public or private harassment -- Publishing other’s private information, such as physical or electronic addresses, without explicit permission -- Other unethical or unprofessional conduct - -Project maintainers have the right and responsibility to remove, edit, or reject comments, -commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of -Conduct, or to ban temporarily or permanently any contributor for other behaviors that they -deem inappropriate, threatening, offensive, or harmful. - -By adopting this Code of Conduct, project maintainers commit themselves to fairly and -consistently applying these principles to every aspect of managing this project. Project -maintainers who do not follow or enforce the Code of Conduct may be permanently removed -from the project team. - -This code of conduct applies both within project spaces and in public spaces when an -individual is representing the project or its community. - -Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by -contacting a project maintainer at [conduct@swiftserver.group](mailto:conduct@swiftserver.group). All complaints will be reviewed and -investigated and will result in a response that is deemed necessary and appropriate to the -circumstances. Maintainers are obligated to maintain confidentiality with regard to the reporter -of an incident. - -*This policy is adapted from the Contributor Code of Conduct [version 1.3.0](https://contributor-covenant.org/version/1/3/0/).* - -### Reporting -A working group of community members is committed to promptly addressing any [reported issues](mailto:conduct@swiftserver.group). -Working group members are volunteers appointed by the project lead, with a -preference for individuals with varied backgrounds and perspectives. Membership is expected -to change regularly, and may grow or shrink. + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3803bb618..dddcb3ba4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -65,10 +65,10 @@ We require that your commit messages match our template. The easiest way to do t git config commit.template dev/git.commit.template -### Make sure Tests work on Linux -AsyncHTTPClient uses XCTest to run tests on both macOS and Linux. While the macOS version of XCTest is able to use the Objective-C runtime to discover tests at execution time, the Linux version is not. -For this reason, whenever you add new tests **you have to run a script** that generates the hooks needed to run those tests on Linux, or our CI will complain that the tests are not all present on Linux. To do this, merely execute `ruby ./scripts/generate_linux_tests.rb` at the root of the package and check the changes it made. +### Run CI checks locally + +You can run the Github Actions workflows locally using [act](https://github.com/nektos/act). For detailed steps on how to do this please see [https://github.com/swiftlang/github-workflows?tab=readme-ov-file#running-workflows-locally](https://github.com/swiftlang/github-workflows?tab=readme-ov-file#running-workflows-locally). ## How to contribute your work diff --git a/Examples/GetHTML/GetHTML.swift b/Examples/GetHTML/GetHTML.swift index dfefa922b..ca3bacbea 100644 --- a/Examples/GetHTML/GetHTML.swift +++ b/Examples/GetHTML/GetHTML.swift @@ -18,12 +18,12 @@ import NIOCore @main struct GetHTML { static func main() async throws { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) do { let request = HTTPClientRequest(url: "https://apple.com") let response = try await httpClient.execute(request, timeout: .seconds(30)) print("HTTP head", response) - let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB + let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB print(String(buffer: body)) } catch { print("request failed:", error) diff --git a/Examples/GetJSON/GetJSON.swift b/Examples/GetJSON/GetJSON.swift index ae58ffeaa..1af7a5144 100644 --- a/Examples/GetJSON/GetJSON.swift +++ b/Examples/GetJSON/GetJSON.swift @@ -33,12 +33,12 @@ struct Comic: Codable { @main struct GetJSON { static func main() async throws { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) do { let request = HTTPClientRequest(url: "https://xkcd.com/info.0.json") let response = try await httpClient.execute(request, timeout: .seconds(30)) print("HTTP head", response) - let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB + let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB // we use an overload defined in `NIOFoundationCompat` for `decode(_:from:)` to // efficiently decode from a `ByteBuffer` let comic = try JSONDecoder().decode(Comic.self, from: body) diff --git a/Examples/Package.swift b/Examples/Package.swift index 696092cba..9986b17b5 100644 --- a/Examples/Package.swift +++ b/Examples/Package.swift @@ -43,7 +43,8 @@ let package = Package( dependencies: [ .product(name: "AsyncHTTPClient", package: "async-http-client"), .product(name: "NIOCore", package: "swift-nio"), - ], path: "GetHTML" + ], + path: "GetHTML" ), .executableTarget( name: "GetJSON", @@ -51,14 +52,16 @@ let package = Package( .product(name: "AsyncHTTPClient", package: "async-http-client"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOFoundationCompat", package: "swift-nio"), - ], path: "GetJSON" + ], + path: "GetJSON" ), .executableTarget( name: "StreamingByteCounter", dependencies: [ .product(name: "AsyncHTTPClient", package: "async-http-client"), .product(name: "NIOCore", package: "swift-nio"), - ], path: "StreamingByteCounter" + ], + path: "StreamingByteCounter" ), ] ) diff --git a/Examples/StreamingByteCounter/StreamingByteCounter.swift b/Examples/StreamingByteCounter/StreamingByteCounter.swift index dc340d14b..ecfb48776 100644 --- a/Examples/StreamingByteCounter/StreamingByteCounter.swift +++ b/Examples/StreamingByteCounter/StreamingByteCounter.swift @@ -18,7 +18,7 @@ import NIOCore @main struct StreamingByteCounter { static func main() async throws { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) do { let request = HTTPClientRequest(url: "https://apple.com") let response = try await httpClient.execute(request, timeout: .seconds(30)) diff --git a/NOTICE.txt b/NOTICE.txt index 095a11740..86a969171 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -50,13 +50,13 @@ This product contains a derivation of the Tony Stone's 'process_test_files.rb'. * https://www.apache.org/licenses/LICENSE-2.0 * HOMEPAGE: * https://github.com/tonystone/build-tools/commit/6c417b7569df24597a48a9aa7b505b636e8f73a1 - * https://github.com/tonystone/build-tools/blob/master/source/xctest_tool.rb + * https://github.com/tonystone/build-tools/blob/cf3440f43bde2053430285b4ed0709c865892eb5/source/xctest_tool.rb --- This product contains a derivation of Fabian Fett's 'Base64.swift'. * LICENSE (Apache License 2.0): - * https://github.com/fabianfett/swift-base64-kit/blob/master/LICENSE + * https://github.com/swift-extras/swift-extras-base64/blob/b8af49699d59ad065b801715a5009619100245ca/LICENSE * HOMEPAGE: * https://github.com/fabianfett/swift-base64-kit diff --git a/Package.swift b/Package.swift index 5deb0de31..aad0c1c53 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.4 +// swift-tools-version:6.0 //===----------------------------------------------------------------------===// // // This source file is part of the AsyncHTTPClient open source project @@ -15,43 +15,71 @@ import PackageDescription +let strictConcurrencyDevelopment = false + +let strictConcurrencySettings: [SwiftSetting] = { + var initialSettings: [SwiftSetting] = [] + + if strictConcurrencyDevelopment { + // -warnings-as-errors here is a workaround so that IDE-based development can + // get tripped up on -require-explicit-sendable. + initialSettings.append(.unsafeFlags(["-Xfrontend", "-require-explicit-sendable", "-warnings-as-errors"])) + } + + return initialSettings +}() + let package = Package( name: "async-http-client", products: [ - .library(name: "AsyncHTTPClient", targets: ["AsyncHTTPClient"]), + .library(name: "AsyncHTTPClient", targets: ["AsyncHTTPClient"]) ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.38.0"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.14.1"), - .package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.19.0"), - .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.10.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), - .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.30.0"), + .package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.36.0"), + .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.26.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.24.0"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.7.1"), + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), + .package(url: "https://github.com/apple/swift-algorithms.git", from: "1.0.0"), + .package(url: "https://github.com/apple/swift-distributed-tracing.git", from: "1.3.0"), ], targets: [ - .target(name: "CAsyncHTTPClient"), + .target( + name: "CAsyncHTTPClient", + cSettings: [ + .define("_GNU_SOURCE") + ] + ), .target( name: "AsyncHTTPClient", dependencies: [ .target(name: "CAsyncHTTPClient"), .product(name: "NIO", package: "swift-nio"), + .product(name: "NIOTLS", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), - .product(name: "NIOFoundationCompat", package: "swift-nio"), .product(name: "NIOHTTP2", package: "swift-nio-http2"), .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOHTTPCompression", package: "swift-nio-extras"), .product(name: "NIOSOCKS", package: "swift-nio-extras"), .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), + .product(name: "Atomics", package: "swift-atomics"), + .product(name: "Algorithms", package: "swift-algorithms"), + // Observability support .product(name: "Logging", package: "swift-log"), - ] + .product(name: "Tracing", package: "swift-distributed-tracing"), + ], + swiftSettings: strictConcurrencySettings ), .testTarget( name: "AsyncHTTPClientTests", dependencies: [ .target(name: "AsyncHTTPClient"), + .product(name: "NIOTLS", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOEmbedded", package: "swift-nio"), @@ -60,8 +88,37 @@ let package = Package( .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOHTTP2", package: "swift-nio-http2"), .product(name: "NIOSOCKS", package: "swift-nio-extras"), + .product(name: "Atomics", package: "swift-atomics"), + .product(name: "Algorithms", package: "swift-algorithms"), + // Observability support .product(name: "Logging", package: "swift-log"), - ] + .product(name: "InMemoryLogging", package: "swift-log"), + .product(name: "Tracing", package: "swift-distributed-tracing"), + .product(name: "InMemoryTracing", package: "swift-distributed-tracing"), + ], + resources: [ + .copy("Resources/self_signed_cert.pem"), + .copy("Resources/self_signed_key.pem"), + .copy("Resources/example.com.cert.pem"), + .copy("Resources/example.com.private-key.pem"), + ], + swiftSettings: strictConcurrencySettings ), ] ) + +// --- STANDARD CROSS-REPO SETTINGS DO NOT EDIT --- // +for target in package.targets { + switch target.type { + case .regular, .test, .executable: + var settings = target.swiftSettings ?? [] + // https://github.com/swiftlang/swift-evolution/blob/main/proposals/0444-member-import-visibility.md + settings.append(.enableUpcomingFeature("MemberImportVisibility")) + target.swiftSettings = settings + case .macro, .plugin, .system, .binary: + () // not applicable + @unknown default: + () // we don't know what to do here, do nothing + } +} +// --- END: STANDARD CROSS-REPO SETTINGS DO NOT EDIT --- // diff --git a/README.md b/README.md index 6dad76de3..b557e58fa 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,15 @@ # AsyncHTTPClient -This package provides simple HTTP Client library built on top of SwiftNIO. +This package provides an HTTP Client library built on top of SwiftNIO. This library provides the following: -- First class support for Swift Concurrency (since version 1.9.0) +- First class support for Swift Concurrency - Asynchronous and non-blocking request methods - Simple follow-redirects (cookie headers are dropped) - Streaming body download - TLS support -- Automatic HTTP/2 over HTTPS (since version 1.7.0) +- Automatic HTTP/2 over HTTPS - Cookie parsing (but not storage) ---- - -**NOTE**: You will need [Xcode 13.2](https://apps.apple.com/gb/app/xcode/id497799835?mt=12) or [Swift 5.5.2](https://swift.org/download/#swift-552) to try out `AsyncHTTPClient`s new async/await APIs. - ---- - ## Getting Started #### Adding the dependency @@ -33,18 +27,12 @@ and `AsyncHTTPClient` dependency to your target: The code snippet below illustrates how to make a simple GET request to a remote server. -Please note that the example will spawn a new `EventLoopGroup` which will _create fresh threads_ which is a very costly operation. In a real-world application that uses [SwiftNIO](https://github.com/apple/swift-nio) for other parts of your application (for example a web server), please prefer `eventLoopGroupProvider: .shared(myExistingEventLoopGroup)` to share the `EventLoopGroup` used by AsyncHTTPClient with other parts of your application. - -If your application does not use SwiftNIO yet, it is acceptable to use `eventLoopGroupProvider: .createNew` but please make sure to share the returned `HTTPClient` instance throughout your whole application. Do not create a large number of `HTTPClient` instances with `eventLoopGroupProvider: .createNew`, this is very wasteful and might exhaust the resources of your program. - ```swift import AsyncHTTPClient -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) - /// MARK: - Using Swift Concurrency let request = HTTPClientRequest(url: "https://apple.com/") -let response = try await httpClient.execute(request, timeout: .seconds(30)) +let response = try await HTTPClient.shared.execute(request, timeout: .seconds(30)) print("HTTP head", response) if response.status == .ok { let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB @@ -55,7 +43,7 @@ if response.status == .ok { /// MARK: - Using SwiftNIO EventLoopFuture -httpClient.get(url: "https://apple.com/").whenComplete { result in +HTTPClient.shared.get(url: "https://apple.com/").whenComplete { result in switch result { case .failure(let error): // process error @@ -69,7 +57,8 @@ httpClient.get(url: "https://apple.com/").whenComplete { result in } ``` -You should always shut down `HTTPClient` instances you created using `try httpClient.syncShutdown()`. Please note that you must not call `httpClient.syncShutdown` before all requests of the HTTP client have finished, or else the in-flight requests will likely fail because their network connections are interrupted. +If you create your own `HTTPClient` instances, you should shut them down using `httpClient.shutdown()` when you're done using them. Failing to do so will leak resources. + Please note that you must not call `httpClient.shutdown` before all requests of the HTTP client have finished, or else the in-flight requests will likely fail because their network connections are interrupted. ### async/await examples @@ -84,14 +73,13 @@ The default HTTP Method is `GET`. In case you need to have more control over the ```swift import AsyncHTTPClient -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) do { var request = HTTPClientRequest(url: "https://apple.com/") request.method = .POST request.headers.add(name: "User-Agent", value: "Swift HTTPClient") request.body = .bytes(ByteBuffer(string: "some data")) - let response = try await httpClient.execute(request, timeout: .seconds(30)) + let response = try await HTTPClient.shared.execute(request, timeout: .seconds(30)) if response.status == .ok { // handle response } else { @@ -100,8 +88,6 @@ do { } catch { // handle error } -// it's important to shutdown the httpClient after all requests are done, even if one failed -try await httpClient.shutdown() ``` #### Using SwiftNIO EventLoopFuture @@ -109,16 +95,11 @@ try await httpClient.shutdown() ```swift import AsyncHTTPClient -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) -defer { - try? httpClient.syncShutdown() -} - var request = try HTTPClient.Request(url: "https://apple.com/", method: .POST) request.headers.add(name: "User-Agent", value: "Swift HTTPClient") request.body = .string("some-body") -httpClient.execute(request: request).whenComplete { result in +HTTPClient.shared.execute(request: request).whenComplete { result in switch result { case .failure(let error): // process error @@ -133,9 +114,11 @@ httpClient.execute(request: request).whenComplete { result in ``` ### Redirects following -Enable follow-redirects behavior using the client configuration: + +The globally shared instance `HTTPClient.shared` follows redirects by default. If you create your own `HTTPClient`, you can enable the follow-redirects behavior using the client configuration: + ```swift -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, +let httpClient = HTTPClient(eventLoopGroupProvider: .singleton, configuration: HTTPClient.Configuration(followRedirects: true)) ``` @@ -143,7 +126,7 @@ let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, Timeouts (connect and read) can also be set using the client configuration: ```swift let timeout = HTTPClient.Configuration.Timeout(connect: .seconds(1), read: .seconds(1)) -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, +let httpClient = HTTPClient(eventLoopGroupProvider: .singleton, configuration: HTTPClient.Configuration(timeout: timeout)) ``` or on a per-request basis: @@ -152,15 +135,14 @@ httpClient.execute(request: request, deadline: .now() + .milliseconds(1)) ``` ### Streaming -When dealing with larger amount of data, it's critical to stream the response body instead of aggregating in-memory. +When dealing with larger amount of data, it's critical to stream the response body instead of aggregating in-memory. The following example demonstrates how to count the number of bytes in a streaming response body: #### Using Swift Concurrency ```swift -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) do { let request = HTTPClientRequest(url: "https://apple.com/") - let response = try await httpClient.execute(request, timeout: .seconds(30)) + let response = try await HTTPClient.shared.execute(request, timeout: .seconds(30)) print("HTTP head", response) // if defined, the content-length headers announces the size of the body @@ -172,7 +154,7 @@ do { for try await buffer in response.body { // for this example, we are just interested in the size of the fragment receivedBytes += buffer.readableBytes - + if let expectedBytes = expectedBytes { // if the body size is known, we calculate a progress indicator let progress = Double(receivedBytes) / Double(expectedBytes) @@ -181,10 +163,8 @@ do { } print("did receive \(receivedBytes) bytes") } catch { - print("request failed:", error) + print("request failed:", error) } -// it is important to shutdown the httpClient after all requests are done, even if one failed -try await httpClient.shutdown() ``` #### Using HTTPClientResponseDelegate and SwiftNIO EventLoopFuture @@ -211,17 +191,17 @@ class CountingDelegate: HTTPClientResponseDelegate { } func didReceiveHead( - task: HTTPClient.Task, + task: HTTPClient.Task, _ head: HTTPResponseHead ) -> EventLoopFuture { - // this is executed when we receive HTTP response head part of the request - // (it contains response code and headers), called once in case backpressure + // this is executed when we receive HTTP response head part of the request + // (it contains response code and headers), called once in case backpressure // is needed, all reads will be paused until returned future is resolved return task.eventLoop.makeSucceededFuture(()) } func didReceiveBodyPart( - task: HTTPClient.Task, + task: HTTPClient.Task, _ buffer: ByteBuffer ) -> EventLoopFuture { // this is executed when we receive parts of the response body, could be called zero or more times @@ -244,7 +224,7 @@ class CountingDelegate: HTTPClientResponseDelegate { let request = try HTTPClient.Request(url: "https://apple.com/") let delegate = CountingDelegate() -httpClient.execute(request: request, delegate: delegate).futureResult.whenSuccess { count in +HTTPClient.shared.execute(request: request, delegate: delegate).futureResult.whenSuccess { count in print(count) } ``` @@ -257,7 +237,6 @@ asynchronously, while reporting the download progress at the same time, like in example: ```swift -let client = HTTPClient(eventLoopGroupProvider: .createNew) let request = try HTTPClient.Request( url: "https://swift.org/builds/development/ubuntu1804/latest-build.yml" ) @@ -269,7 +248,7 @@ let delegate = try FileDownloadDelegate(path: "/tmp/latest-build.yml", reportPro print("Downloaded \($0.receivedBytes) bytes so far") }) -client.execute(request: request, delegate: delegate).futureResult +HTTPClient.shared.execute(request: request, delegate: delegate).futureResult .whenSuccess { progress in if let totalBytes = progress.totalBytes { print("Final total bytes count: \(totalBytes)") @@ -281,21 +260,19 @@ client.execute(request: request, delegate: delegate).futureResult ### Unix Domain Socket Paths Connecting to servers bound to socket paths is easy: ```swift -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) -httpClient.execute( - .GET, - socketPath: "/tmp/myServer.socket", +HTTPClient.shared.execute( + .GET, + socketPath: "/tmp/myServer.socket", urlPath: "/path/to/resource" ).whenComplete (...) ``` Connecting over TLS to a unix domain socket path is possible as well: ```swift -let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) -httpClient.execute( - .POST, - secureSocketPath: "/tmp/myServer.socket", - urlPath: "/path/to/resource", +HTTPClient.shared.execute( + .POST, + secureSocketPath: "/tmp/myServer.socket", + urlPath: "/path/to/resource", body: .string("hello") ).whenComplete (...) ``` @@ -303,11 +280,11 @@ httpClient.execute( Direct URLs can easily be constructed to be executed in other scenarios: ```swift let socketPathBasedURL = URL( - httpURLWithSocketPath: "/tmp/myServer.socket", + httpURLWithSocketPath: "/tmp/myServer.socket", uri: "/path/to/resource" ) let secureSocketPathBasedURL = URL( - httpsURLWithSocketPath: "/tmp/myServer.socket", + httpsURLWithSocketPath: "/tmp/myServer.socket", uri: "/path/to/resource" ) ``` @@ -318,7 +295,7 @@ The exclusive use of HTTP/1 is possible by setting `httpVersion` to `.http1Only` var configuration = HTTPClient.Configuration() configuration.httpVersion = .http1Only let client = HTTPClient( - eventLoopGroupProvider: .createNew, + eventLoopGroupProvider: .singleton, configuration: configuration ) ``` @@ -326,3 +303,20 @@ let client = HTTPClient( ## Security Please have a look at [SECURITY.md](SECURITY.md) for AsyncHTTPClient's security process. + +## Supported Versions + +The most recent versions of AsyncHTTPClient support Swift 6.0 and newer. The minimum Swift version supported by AsyncHTTPClient releases are detailed below: + +AsyncHTTPClient | Minimum Swift Version +--------------------|---------------------- +`1.0.0 ..< 1.5.0` | 5.0 +`1.5.0 ..< 1.10.0` | 5.2 +`1.10.0 ..< 1.13.0` | 5.4 +`1.13.0 ..< 1.18.0` | 5.5.2 +`1.18.0 ..< 1.20.0` | 5.6 +`1.20.0 ..< 1.21.0` | 5.7 +`1.21.0 ..< 1.26.0` | 5.8 +`1.26.0 ..< 1.27.0` | 5.9 +`1.27.0 ..< 1.30.0` | 5.10 +`1.30.0 ...` | 6.0 diff --git a/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift new file mode 100644 index 000000000..fbcc82ec1 --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift @@ -0,0 +1,51 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@usableFromInline +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +struct AnyAsyncSequence: Sendable, AsyncSequence { + @usableFromInline typealias AsyncIteratorNextCallback = () async throws -> Element? + + @usableFromInline struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline let nextCallback: AsyncIteratorNextCallback + + @inlinable init(nextCallback: @escaping AsyncIteratorNextCallback) { + self.nextCallback = nextCallback + } + + @inlinable mutating func next() async throws -> Element? { + try await self.nextCallback() + } + } + + @usableFromInline var makeAsyncIteratorCallback: @Sendable () -> AsyncIteratorNextCallback + + @inlinable init( + _ asyncSequence: SequenceOfBytes + ) where SequenceOfBytes: AsyncSequence & Sendable, SequenceOfBytes.Element == Element { + self.makeAsyncIteratorCallback = { + var iterator = asyncSequence.makeAsyncIterator() + return { + try await iterator.next() + } + } + } + + @inlinable func makeAsyncIterator() -> AsyncIterator { + .init(nextCallback: self.makeAsyncIteratorCallback()) + } +} + +@available(*, unavailable) +extension AnyAsyncSequence.AsyncIterator: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequenceProucerDelete.swift b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequenceProucerDelete.swift new file mode 100644 index 000000000..1e35df7f2 --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequenceProucerDelete.swift @@ -0,0 +1,37 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +@usableFromInline +struct AnyAsyncSequenceProducerDelegate: NIOAsyncSequenceProducerDelegate { + @usableFromInline + var delegate: NIOAsyncSequenceProducerDelegate + + @inlinable + init(_ delegate: Delegate) { + self.delegate = delegate + } + + @inlinable + func produceMore() { + self.delegate.produceMore() + } + + @inlinable + func didTerminate() { + self.delegate.didTerminate() + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/AsyncLazySequence.swift b/Sources/AsyncHTTPClient/AsyncAwait/AsyncLazySequence.swift new file mode 100644 index 000000000..fe37dd5e7 --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/AsyncLazySequence.swift @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +@usableFromInline +struct AsyncLazySequence: AsyncSequence { + @usableFromInline typealias Element = Base.Element + @usableFromInline struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline var iterator: Base.Iterator + @inlinable init(iterator: Base.Iterator) { + self.iterator = iterator + } + + @inlinable mutating func next() async throws -> Base.Element? { + self.iterator.next() + } + } + + @usableFromInline var base: Base + + @inlinable init(base: Base) { + self.base = base + } + + @inlinable func makeAsyncIterator() -> AsyncIterator { + .init(iterator: self.base.makeIterator()) + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension AsyncLazySequence: Sendable where Base: Sendable {} +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension AsyncLazySequence.AsyncIterator: Sendable where Base.Iterator: Sendable {} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension Sequence { + /// Turns `self` into an `AsyncSequence` by vending each element of `self` asynchronously. + @inlinable var async: AsyncLazySequence { + .init(base: self) + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift index 043ad510b..bbf8c948c 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) -import struct Foundation.URL import Logging import NIOCore import NIOHTTP1 +import Tracing + +import struct Foundation.URL @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClient { @@ -26,18 +27,24 @@ extension HTTPClient { /// - request: HTTP request to execute. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. + /// + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. + /// /// - Returns: The response to the request. Note that the `body` of the response may not yet have been fully received. public func execute( _ request: HTTPClientRequest, deadline: NIODeadline, logger: Logger? = nil ) async throws -> HTTPClientResponse { - try await self.executeAndFollowRedirectsIfNeeded( - request, - deadline: deadline, - logger: logger ?? Self.loggingDisabled, - redirectState: RedirectState(self.configuration.redirectConfiguration.mode, initialURL: request.url) - ) + try await withRequestSpan(request) { + try await self.executeAndFollowRedirectsIfNeeded( + request, + deadline: deadline, + logger: logger ?? Self.loggingDisabled, + redirectState: RedirectState(self.configuration.redirectConfiguration.mode, initialURL: request.url) + ) + } } } @@ -51,6 +58,10 @@ extension HTTPClient { /// - request: HTTP request to execute. /// - timeout: time the the request has to complete. /// - logger: The logger to use for this request. + /// + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. + /// /// - Returns: The response to the request. Note that the `body` of the response may not yet have been fully received. public func execute( _ request: HTTPClientRequest, @@ -67,6 +78,8 @@ extension HTTPClient { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClient { + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. private func executeAndFollowRedirectsIfNeeded( _ request: HTTPClientRequest, deadline: NIODeadline, @@ -75,22 +88,47 @@ extension HTTPClient { ) async throws -> HTTPClientResponse { var currentRequest = request var currentRedirectState = redirectState + var history: [HTTPClientRequestResponse] = [] // this loop is there to follow potential redirects while true { - let preparedRequest = try HTTPClientRequest.Prepared(currentRequest) - let response = try await executeCancellable(preparedRequest, deadline: deadline, logger: logger) + let preparedRequest = + try HTTPClientRequest.Prepared( + currentRequest, + dnsOverride: configuration.dnsOverride, + tracing: self.configuration.tracing + ) + let response = try await { + var response = try await self.executeCancellable(preparedRequest, deadline: deadline, logger: logger) + + history.append( + .init( + request: currentRequest, + responseHead: .init( + version: response.version, + status: response.status, + headers: response.headers + ) + ) + ) + + response.history = history + + return response + }() guard var redirectState = currentRedirectState else { // a `nil` redirectState means we should not follow redirects return response } - guard let redirectURL = response.headers.extractRedirectTarget( - status: response.status, - originalURL: preparedRequest.url, - originalScheme: preparedRequest.poolKey.scheme - ) else { + guard + let redirectURL = response.headers.extractRedirectTarget( + status: response.status, + originalURL: preparedRequest.url, + originalScheme: preparedRequest.poolKey.scheme + ) + else { // response does not want a redirect return response } @@ -114,6 +152,8 @@ extension HTTPClient { } } + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. private func executeCancellable( _ request: HTTPClientRequest.Prepared, deadline: NIODeadline, @@ -121,31 +161,35 @@ extension HTTPClient { ) async throws -> HTTPClientResponse { let cancelHandler = TransactionCancelHandler() - return try await withTaskCancellationHandler(operation: { () async throws -> HTTPClientResponse in - let eventLoop = self.eventLoopGroup.any() - let deadlineTask = eventLoop.scheduleTask(deadline: deadline) { - cancelHandler.cancel(reason: .deadlineExceeded) - } - defer { - deadlineTask.cancel() - } - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) -> Void in - let transaction = Transaction( - request: request, - requestOptions: .init(idleReadTimeout: nil), - logger: logger, - connectionDeadline: deadline, - preferredEventLoop: eventLoop, - responseContinuation: continuation - ) + return try await withTaskCancellationHandler( + operation: { () async throws -> HTTPClientResponse in + let eventLoop = self.eventLoopGroup.any() + let deadlineTask = eventLoop.scheduleTask(deadline: deadline) { + cancelHandler.cancel(reason: .deadlineExceeded) + } + defer { + deadlineTask.cancel() + } + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) -> Void in + let transaction = Transaction( + request: request, + requestOptions: .fromClientConfiguration(self.configuration), + logger: logger, + connectionDeadline: .now() + (self.configuration.timeout.connectionCreationTimeout), + preferredEventLoop: eventLoop, + responseContinuation: continuation + ) - cancelHandler.registerTransaction(transaction) + cancelHandler.registerTransaction(transaction) - self.poolManager.executeRequest(transaction) + self.poolManager.executeRequest(transaction) + } + }, + onCancel: { + cancelHandler.cancel(reason: .taskCanceled) } - }, onCancel: { - cancelHandler.cancel(reason: .taskCanceled) - }) + ) } } @@ -215,5 +259,3 @@ private actor TransactionCancelHandler { } } } - -#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+shutdown.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+shutdown.swift index 4e7090dbf..43020c3e5 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+shutdown.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+shutdown.swift @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) +import NIOCore extension HTTPClient { /// Shuts down the client and `EventLoopGroup` if it was created by the client. @@ -30,5 +30,3 @@ extension HTTPClient { } } } - -#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+tracing.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+tracing.swift new file mode 100644 index 000000000..0be737619 --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+tracing.swift @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOHTTP1 +import Tracing + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClient { + @inlinable + func withRequestSpan( + _ request: HTTPClientRequest, + _ body: () async throws -> HTTPClientResponse + ) async rethrows -> HTTPClientResponse { + guard let tracer = self.tracer else { + return try await body() + } + + return try await tracer.withSpan(request.method.rawValue, ofKind: .client) { span in + let keys = self.configuration.tracing.attributeKeys + span.attributes[keys.requestMethod] = request.method.rawValue + // TODO: set more attributes on the span + let response = try await body() + + // set response span attributes + TracingSupport.handleResponseStatusCode(span, response.status, keys: tracing.attributeKeys) + + return response + } + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift index de09df5b8..b5649cf90 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift @@ -12,25 +12,47 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) -import struct Foundation.URL +import Instrumentation +import NIOCore import NIOHTTP1 +import NIOSSL +import ServiceContextModule + +import struct Foundation.URL @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest { struct Prepared { + enum Body { + case asyncSequence( + length: RequestBodyLength, + makeAsyncIterator: @Sendable () -> ((ByteBufferAllocator) async throws -> ByteBuffer?) + ) + case sequence( + length: RequestBodyLength, + canBeConsumedMultipleTimes: Bool, + makeCompleteBody: (ByteBufferAllocator) -> ByteBuffer + ) + case byteBuffer(ByteBuffer) + } + var url: URL var poolKey: ConnectionPool.Key var requestFramingMetadata: RequestFramingMetadata var head: HTTPRequestHead var body: Body? + var tlsConfiguration: TLSConfiguration? } } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest.Prepared { - init(_ request: HTTPClientRequest) throws { - guard let url = URL(string: request.url) else { + init( + _ request: HTTPClientRequest, + dnsOverride: [String: String] = [:], + tracing: HTTPClient.TracingConfiguration? = nil + ) throws { + guard !request.url.isEmpty, let url = URL(string: request.url) else { throw HTTPClientError.invalidURL } @@ -38,6 +60,12 @@ extension HTTPClientRequest.Prepared { var headers = request.headers headers.addHostIfNeeded(for: deconstructedURL) + if let tracer = tracing?.tracer, + let context = ServiceContext.current + { + tracer.inject(context, into: &headers, using: HTTPHeadersInjector.shared) + } + let metadata = try headers.validateAndSetTransportFraming( method: request.method, bodyLength: .init(request.body) @@ -45,7 +73,7 @@ extension HTTPClientRequest.Prepared { self.init( url: url, - poolKey: .init(url: deconstructedURL, tlsConfiguration: nil), + poolKey: .init(url: deconstructedURL, tlsConfiguration: request.tlsConfiguration, dnsOverride: dnsOverride), requestFramingMetadata: metadata, head: .init( version: .http1_1, @@ -53,11 +81,30 @@ extension HTTPClientRequest.Prepared { uri: deconstructedURL.uri, headers: headers ), - body: request.body + body: request.body.map { .init($0) }, + tlsConfiguration: request.tlsConfiguration ) } } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientRequest.Prepared.Body { + init(_ body: HTTPClientRequest.Body) { + switch body.mode { + case .asyncSequence(let length, let makeAsyncIterator): + self = .asyncSequence(length: length, makeAsyncIterator: makeAsyncIterator) + case .sequence(let length, let canBeConsumedMultipleTimes, let makeCompleteBody): + self = .sequence( + length: length, + canBeConsumedMultipleTimes: canBeConsumedMultipleTimes, + makeCompleteBody: makeCompleteBody + ) + case .byteBuffer(let byteBuffer): + self = .byteBuffer(byteBuffer) + } + } +} + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension RequestBodyLength { init(_ body: HTTPClientRequest.Body?) { @@ -65,7 +112,7 @@ extension RequestBodyLength { case .none: self = .known(0) case .byteBuffer(let buffer): - self = .known(buffer.readableBytes) + self = .known(Int64(buffer.readableBytes)) case .sequence(let length, _, _), .asyncSequence(let length, _): self = length } @@ -94,5 +141,3 @@ extension HTTPClientRequest { return newRequest } } - -#endif diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+auth.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+auth.swift new file mode 100644 index 000000000..106a8f76b --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+auth.swift @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2024 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Foundation + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientRequest { + /// Set basic auth for a request. + /// + /// - parameters: + /// - username: the username to authenticate with + /// - password: authentication password associated with the username + public mutating func setBasicAuth(username: String, password: String) { + self.headers.setBasicAuth(username: username, password: password) + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift index cfab828a0..dca7de0ef 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift @@ -12,33 +12,85 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) +import Algorithms import NIOCore import NIOHTTP1 +import NIOSSL +@usableFromInline +let bagOfBytesToByteBufferConversionChunkSize = 1024 * 1024 * 4 + +#if arch(arm) || arch(i386) +// on 32-bit platforms we can't make use of a whole UInt32.max (as it doesn't fit in an Int) +@usableFromInline +let byteBufferMaxSize = Int.max +#else +// on 64-bit platforms we're good +@usableFromInline +let byteBufferMaxSize = Int(UInt32.max) +#endif + +/// A representation of an HTTP request for the Swift Concurrency HTTPClient API. +/// +/// This object is similar to ``HTTPClient/Request``, but used for the Swift Concurrency API. +/// +/// - note: For many ``HTTPClientRequest/body-swift.property`` configurations, this type is _not_ a value type +/// (https://github.com/swift-server/async-http-client/issues/708). @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -public struct HTTPClientRequest { +public struct HTTPClientRequest: Sendable { + /// The request URL, including scheme, hostname, and optionally port. public var url: String + + /// The request method. public var method: HTTPMethod + + /// The request headers. public var headers: HTTPHeaders + /// The request body, if any. public var body: Body? + /// Request-specific TLS configuration, defaults to no request-specific TLS configuration. + public var tlsConfiguration: TLSConfiguration? + public init(url: String) { self.url = url self.method = .GET self.headers = .init() self.body = .none + self.tlsConfiguration = nil } } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest { - public struct Body { + /// An HTTP request body. + /// + /// This object encapsulates the difference between streamed HTTP request bodies and those bodies that + /// are already entirely in memory. + public struct Body: Sendable { @usableFromInline - internal enum Mode { - case asyncSequence(length: RequestBodyLength, (ByteBufferAllocator) async throws -> ByteBuffer?) - case sequence(length: RequestBodyLength, canBeConsumedMultipleTimes: Bool, (ByteBufferAllocator) -> ByteBuffer) + internal enum Mode: Sendable { + /// - parameters: + /// - length: complete body length. + /// If `length` is `.known`, `nextBodyPart` is not allowed to produce more bytes than `length` defines. + /// - makeAsyncIterator: Creates a new async iterator under the hood and returns a function which will call `next()` on it. + /// The returned function then produce the next body buffer asynchronously. + /// We use a closure as an abstraction instead of an existential to enable specialization. + case asyncSequence( + length: RequestBodyLength, + makeAsyncIterator: @Sendable () -> ((ByteBufferAllocator) async throws -> ByteBuffer?) + ) + /// - parameters: + /// - length: complete body length. + /// If `length` is `.known`, `nextBodyPart` is not allowed to produce more bytes than `length` defines. + /// - canBeConsumedMultipleTimes: if `makeBody` can be called multiple times and returns the same result. + /// - makeCompleteBody: function to produce the complete body. + case sequence( + length: RequestBodyLength, + canBeConsumedMultipleTimes: Bool, + makeCompleteBody: @Sendable (ByteBufferAllocator) -> ByteBuffer + ) case byteBuffer(ByteBuffer) } @@ -54,91 +106,233 @@ extension HTTPClientRequest { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest.Body { + /// Create an ``HTTPClientRequest/Body-swift.struct`` from a `ByteBuffer`. + /// + /// - parameter byteBuffer: The bytes of the body. public static func bytes(_ byteBuffer: ByteBuffer) -> Self { self.init(.byteBuffer(byteBuffer)) } + /// Create an ``HTTPClientRequest/Body-swift.struct`` from a `RandomAccessCollection` of bytes. + /// + /// This construction will flatten the `bytes` into a `ByteBuffer` in chunks of ~4MB. + /// As a result, the peak memory usage of this construction will be a small multiple of ~4MB. + /// The construction of the `ByteBuffer` will be delayed until it's needed. + /// + /// - parameter bytes: The bytes of the request body. @inlinable - public static func bytes( + @preconcurrency + public static func bytes( _ bytes: Bytes ) -> Self where Bytes.Element == UInt8 { - self.init(.sequence( - length: .known(bytes.count), - canBeConsumedMultipleTimes: true - ) { allocator in - if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) { - // fastpath - return buffer - } - // potentially really slow path - return allocator.buffer(bytes: bytes) - }) + self.bytes(bytes, length: .known(Int64(bytes.count))) } + /// Create an ``HTTPClientRequest/Body-swift.struct`` from a `Sequence` of bytes. + /// + /// This construction will flatten the bytes into a `ByteBuffer`. As a result, the peak memory + /// usage of this construction will be double the size of the original collection. The construction + /// of the `ByteBuffer` will be delayed until it's needed. + /// + /// Unlike ``bytes(_:)-1uns7``, this construction does not assume that the body can be replayed. As a result, + /// if a redirect is encountered that would need us to replay the request body, the redirect will instead + /// not be followed. Prefer ``bytes(_:)-1uns7`` wherever possible. + /// + /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths + /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. + /// + /// - parameters: + /// - bytes: The bytes of the request body. + /// - length: The length of the request body. @inlinable - public static func bytes( + @preconcurrency + public static func bytes( _ bytes: Bytes, length: Length ) -> Self where Bytes.Element == UInt8 { - self.init(.sequence( - length: length.storage, - canBeConsumedMultipleTimes: false - ) { allocator in - if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) { - // fastpath - return buffer + Self._bytes( + bytes, + length: length, + bagOfBytesToByteBufferConversionChunkSize: bagOfBytesToByteBufferConversionChunkSize, + byteBufferMaxSize: byteBufferMaxSize + ) + } + + /// internal method to test chunking + @inlinable + @preconcurrency + static func _bytes( + _ bytes: Bytes, + length: Length, + bagOfBytesToByteBufferConversionChunkSize: Int, + byteBufferMaxSize: Int + ) -> Self where Bytes.Element == UInt8 { + // fast path + let body: Self? = bytes.withContiguousStorageIfAvailable { bufferPointer -> Self in + // `some Sequence` is special as it can't be efficiently chunked lazily. + // Therefore we need to do the chunking eagerly if it implements the fast path withContiguousStorageIfAvailable + // If we do it eagerly, it doesn't make sense to do a bunch of small chunks, so we only chunk if it exceeds + // the maximum size of a ByteBuffer. + if bufferPointer.count <= byteBufferMaxSize { + let buffer = ByteBuffer(bytes: bufferPointer) + return Self( + .sequence( + length: length.storage, + canBeConsumedMultipleTimes: true, + makeCompleteBody: { _ in buffer } + ) + ) + } else { + // we need to copy `bufferPointer` eagerly as the pointer is only valid during the call to `withContiguousStorageIfAvailable` + let buffers: [ByteBuffer] = bufferPointer.chunks(ofCount: byteBufferMaxSize).map { + ByteBuffer(bytes: $0) + } + return Self( + .asyncSequence( + length: length.storage, + makeAsyncIterator: { + var iterator = buffers.makeIterator() + return { _ in + iterator.next() + } + } + ) + ) } - // potentially really slow path - return allocator.buffer(bytes: bytes) - }) + } + if let body = body { + return body + } + + // slow path + return Self( + .asyncSequence( + length: length.storage + ) { + var iterator = bytes.makeIterator() + return { allocator in + var buffer = allocator.buffer(capacity: bagOfBytesToByteBufferConversionChunkSize) + while buffer.writableBytes > 0, let byte = iterator.next() { + buffer.writeInteger(byte) + } + if buffer.readableBytes > 0 { + return buffer + } + return nil + } + } + ) } + /// Create an ``HTTPClientRequest/Body-swift.struct`` from a `Collection` of bytes. + /// + /// This construction will flatten the `bytes` into a `ByteBuffer` in chunks of ~4MB. + /// As a result, the peak memory usage of this construction will be a small multiple of ~4MB. + /// The construction of the `ByteBuffer` will be delayed until it's needed. + /// + /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths + /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. + /// + /// - parameters: + /// - bytes: The bytes of the request body. + /// - length: The length of the request body. @inlinable - public static func bytes( + @preconcurrency + public static func bytes( _ bytes: Bytes, length: Length ) -> Self where Bytes.Element == UInt8 { - self.init(.sequence( - length: length.storage, - canBeConsumedMultipleTimes: true - ) { allocator in - if let buffer = bytes.withContiguousStorageIfAvailable({ allocator.buffer(bytes: $0) }) { - // fastpath - return buffer - } - // potentially really slow path - return allocator.buffer(bytes: bytes) - }) + if bytes.count <= bagOfBytesToByteBufferConversionChunkSize { + return self.init( + .sequence( + length: length.storage, + canBeConsumedMultipleTimes: true + ) { allocator in + allocator.buffer(bytes: bytes) + } + ) + } else { + return self.init( + .asyncSequence( + length: length.storage, + makeAsyncIterator: { + var iterator = bytes.chunks(ofCount: bagOfBytesToByteBufferConversionChunkSize).makeIterator() + return { allocator in + guard let chunk = iterator.next() else { + return nil + } + return allocator.buffer(bytes: chunk) + } + } + ) + ) + } } + /// Create an ``HTTPClientRequest/Body-swift.struct`` from an `AsyncSequence` of `ByteBuffer`s. + /// + /// This construction will stream the upload one `ByteBuffer` at a time. + /// + /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths + /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. + /// + /// - parameters: + /// - sequenceOfBytes: The bytes of the request body. + /// - length: The length of the request body. @inlinable - public static func stream( + @preconcurrency + public static func stream( _ sequenceOfBytes: SequenceOfBytes, length: Length ) -> Self where SequenceOfBytes.Element == ByteBuffer { - var iterator = sequenceOfBytes.makeAsyncIterator() - let body = self.init(.asyncSequence(length: length.storage) { _ -> ByteBuffer? in - try await iterator.next() - }) + let body = self.init( + .asyncSequence(length: length.storage) { + var iterator = sequenceOfBytes.makeAsyncIterator() + return { _ -> ByteBuffer? in + try await iterator.next() + } + } + ) return body } + /// Create an ``HTTPClientRequest/Body-swift.struct`` from an `AsyncSequence` of bytes. + /// + /// This construction will consume 4MB chunks from the `Bytes` and send them at once. This optimizes for + /// `AsyncSequence`s where larger chunks are buffered up and available without actually suspending, such + /// as those provided by `FileHandle`. + /// + /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths + /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. + /// + /// - parameters: + /// - bytes: The bytes of the request body. + /// - length: The length of the request body. @inlinable - public static func stream( + @preconcurrency + public static func stream( _ bytes: Bytes, length: Length ) -> Self where Bytes.Element == UInt8 { - var iterator = bytes.makeAsyncIterator() - let body = self.init(.asyncSequence(length: length.storage) { allocator -> ByteBuffer? in - var buffer = allocator.buffer(capacity: 1024) // TODO: Magic number - while buffer.writableBytes > 0, let byte = try await iterator.next() { - buffer.writeInteger(byte) - } - if buffer.readableBytes > 0 { - return buffer + let body = self.init( + .asyncSequence(length: length.storage) { + var iterator = bytes.makeAsyncIterator() + return { allocator -> ByteBuffer? in + var buffer = allocator.buffer(capacity: bagOfBytesToByteBufferConversionChunkSize) + while buffer.writableBytes > 0, let byte = try await iterator.next() { + buffer.writeInteger(byte) + } + if buffer.readableBytes > 0 { + return buffer + } + return nil + } } - return nil - }) + ) return body } } @@ -157,11 +351,19 @@ extension Optional where Wrapped == HTTPClientRequest.Body { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest.Body { - public struct Length { - /// size of the request body is not known before starting the request + /// The length of a HTTP request body. + public struct Length: Sendable { + /// The size of the request body is not known before starting the request public static let unknown: Self = .init(storage: .unknown) - /// size of the request body is fixed and exactly `count` bytes + + /// The size of the request body is known and exactly `count` bytes + @available(*, deprecated, message: "Use `known(_ count: Int64)` with an explicit Int64 argument instead") public static func known(_ count: Int) -> Self { + .init(storage: .known(Int64(count))) + } + + /// The size of the request body is known and exactly `count` bytes + public static func known(_ count: Int64) -> Self { .init(storage: .known(count)) } @@ -170,4 +372,58 @@ extension HTTPClientRequest.Body { } } -#endif +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientRequest.Body: AsyncSequence { + public typealias Element = ByteBuffer + + @inlinable + public func makeAsyncIterator() -> AsyncIterator { + switch self.mode { + case .asyncSequence(_, let makeAsyncIterator): + return .init(storage: .makeNext(makeAsyncIterator())) + case .sequence(_, _, let makeCompleteBody): + return .init(storage: .byteBuffer(makeCompleteBody(AsyncIterator.allocator))) + case .byteBuffer(let byteBuffer): + return .init(storage: .byteBuffer(byteBuffer)) + } + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientRequest.Body { + public struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline + static let allocator = ByteBufferAllocator() + + @usableFromInline + enum Storage { + case byteBuffer(ByteBuffer?) + case makeNext((ByteBufferAllocator) async throws -> ByteBuffer?) + } + + @usableFromInline + var storage: Storage + + @inlinable + init(storage: Storage) { + self.storage = storage + } + + @inlinable + public mutating func next() async throws -> ByteBuffer? { + switch self.storage { + case .byteBuffer(let buffer): + self.storage = .byteBuffer(nil) + return buffer + case .makeNext(let makeNext): + return try await makeNext(Self.allocator) + } + } + } +} + +@available(*, unavailable) +extension HTTPClientRequest.Body.AsyncIterator: Sendable {} + +@available(*, unavailable) +extension HTTPClientRequest.Body.AsyncIterator.Storage: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift index 52f03089b..36c1cb36f 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift @@ -12,105 +12,256 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) import NIOCore import NIOHTTP1 +import struct Foundation.URL + +/// A representation of an HTTP response for the Swift Concurrency HTTPClient API. +/// +/// This object is similar to ``HTTPClient/Response``, but used for the Swift Concurrency API. @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -public struct HTTPClientResponse { +public struct HTTPClientResponse: Sendable { + /// The HTTP version on which the response was received. public var version: HTTPVersion + + /// The HTTP status for this response. public var status: HTTPResponseStatus + + /// The HTTP headers of this response. public var headers: HTTPHeaders + + /// The body of this HTTP response. public var body: Body - public struct Body { - private let bag: Transaction - private let reference: ResponseRef + /// The history of all requests and responses in redirect order. + public var history: [HTTPClientRequestResponse] - fileprivate init(_ transaction: Transaction) { - self.bag = transaction - self.reference = ResponseRef(transaction: transaction) + /// The target URL (after redirects) of the response. + public var url: URL? { + guard let lastRequestURL = self.history.last?.request.url else { + return nil } + + return URL(string: lastRequestURL) } - init( - bag: Transaction, - version: HTTPVersion, - status: HTTPResponseStatus, - headers: HTTPHeaders + @inlinable public init( + version: HTTPVersion = .http1_1, + status: HTTPResponseStatus = .ok, + headers: HTTPHeaders = [:], + body: Body = Body() ) { - self.body = Body(bag) self.version = version self.status = status self.headers = headers + self.body = body + self.history = [] + } + + @inlinable public init( + version: HTTPVersion = .http1_1, + status: HTTPResponseStatus = .ok, + headers: HTTPHeaders = [:], + body: Body = Body(), + history: [HTTPClientRequestResponse] = [] + ) { + self.version = version + self.status = status + self.headers = headers + self.body = body + self.history = history + } + + init( + requestMethod: HTTPMethod, + version: HTTPVersion, + status: HTTPResponseStatus, + headers: HTTPHeaders, + body: TransactionBody, + history: [HTTPClientRequestResponse] + ) { + self.init( + version: version, + status: status, + headers: headers, + body: .init( + .transaction( + body, + expectedContentLength: HTTPClientResponse.expectedContentLength( + requestMethod: requestMethod, + headers: headers, + status: status + ) + ) + ), + history: history + ) } } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension HTTPClientResponse.Body: AsyncSequence { - public typealias Element = AsyncIterator.Element +public struct HTTPClientRequestResponse: Sendable { + public var request: HTTPClientRequest + public var responseHead: HTTPResponseHead - public struct AsyncIterator: AsyncIteratorProtocol { - private let stream: IteratorStream + public init(request: HTTPClientRequest, responseHead: HTTPResponseHead) { + self.request = request + self.responseHead = responseHead + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientResponse { + /// A representation of the response body for an HTTP response. + /// + /// The body is streamed as an `AsyncSequence` of `ByteBuffer`, where each `ByteBuffer` contains + /// an arbitrarily large chunk of data. The boundaries between `ByteBuffer` objects in the sequence + /// are entirely synthetic and have no semantic meaning. + public struct Body: AsyncSequence, Sendable { + public typealias Element = ByteBuffer + public struct AsyncIterator: AsyncIteratorProtocol { + @usableFromInline var storage: Storage.AsyncIterator - fileprivate init(stream: IteratorStream) { - self.stream = stream + @inlinable init(storage: Storage.AsyncIterator) { + self.storage = storage + } + + @inlinable public mutating func next() async throws -> ByteBuffer? { + try await self.storage.next() + } } - public mutating func next() async throws -> ByteBuffer? { - try await self.stream.next() + @usableFromInline var storage: Storage + + @inlinable public func makeAsyncIterator() -> AsyncIterator { + .init(storage: self.storage.makeAsyncIterator()) + } + + @inlinable init(storage: Storage) { + self.storage = storage + } + + /// Accumulates `Body` of `ByteBuffer`s into a single `ByteBuffer`. + /// - Parameters: + /// - maxBytes: The maximum number of bytes this method is allowed to accumulate + /// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`. + /// - Returns: the number of bytes collected over time + @inlinable public func collect(upTo maxBytes: Int) async throws -> ByteBuffer { + switch self.storage { + case .transaction(_, let expectedContentLength): + if let contentLength = expectedContentLength { + if contentLength > maxBytes { + throw NIOTooManyBytesError(maxBytes: maxBytes) + } + } + case .anyAsyncSequence: + break + } + + /// calling collect function within here in order to ensure the correct nested type + func collect(_ body: Body, maxBytes: Int) async throws -> ByteBuffer + where Body.Element == ByteBuffer { + try await body.collect(upTo: maxBytes) + } + return try await collect(self, maxBytes: maxBytes) } } +} - public func makeAsyncIterator() -> AsyncIterator { - AsyncIterator(stream: IteratorStream(bag: self.bag)) +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientResponse { + static func expectedContentLength( + requestMethod: HTTPMethod, + headers: HTTPHeaders, + status: HTTPResponseStatus + ) -> Int? { + if status == .notModified { + return 0 + } else if requestMethod == .HEAD { + return 0 + } else { + let contentLength = headers["content-length"].first.flatMap { Int($0, radix: 10) } + return contentLength + } } } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +@usableFromInline +typealias TransactionBody = NIOThrowingAsyncSequenceProducer< + ByteBuffer, + Error, + NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, + AnyAsyncSequenceProducerDelegate +> + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientResponse.Body { - /// The purpose of this object is to inform the transaction about the response body being deinitialized. - /// If the users has not called `makeAsyncIterator` on the body, before it is deinited, the http - /// request needs to be cancelled. - fileprivate class ResponseRef { - private let transaction: Transaction - - init(transaction: Transaction) { - self.transaction = transaction - } + @usableFromInline enum Storage: Sendable { + case transaction(TransactionBody, expectedContentLength: Int?) + case anyAsyncSequence(AnyAsyncSequence) + } +} - deinit { - self.transaction.responseBodyDeinited() +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientResponse.Body.Storage: AsyncSequence { + @usableFromInline typealias Element = ByteBuffer + + @inlinable func makeAsyncIterator() -> AsyncIterator { + switch self { + case .transaction(let transaction, _): + return .transaction(transaction.makeAsyncIterator()) + case .anyAsyncSequence(let anyAsyncSequence): + return .anyAsyncSequence(anyAsyncSequence.makeAsyncIterator()) } } } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension HTTPClientResponse.Body { - internal class IteratorStream { - struct ID: Hashable { - private let objectID: ObjectIdentifier +extension HTTPClientResponse.Body.Storage { + @usableFromInline enum AsyncIterator { + case transaction(TransactionBody.AsyncIterator) + case anyAsyncSequence(AnyAsyncSequence.AsyncIterator) + } +} - init(_ object: IteratorStream) { - self.objectID = ObjectIdentifier(object) - } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientResponse.Body.Storage.AsyncIterator: AsyncIteratorProtocol { + @inlinable mutating func next() async throws -> ByteBuffer? { + switch self { + case .transaction(let iterator): + return try await iterator.next() + case .anyAsyncSequence(var iterator): + defer { self = .anyAsyncSequence(iterator) } + return try await iterator.next() } + } +} - private var id: ID { ID(self) } - private let bag: Transaction +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientResponse.Body { + @inlinable init(_ storage: Storage) { + self.storage = storage + } - init(bag: Transaction) { - self.bag = bag - } + public init() { + self = .stream(EmptyCollection().async) + } - deinit { - self.bag.responseBodyIteratorDeinited(streamID: self.id) - } + @inlinable public static func stream( + _ sequenceOfBytes: SequenceOfBytes + ) -> Self where SequenceOfBytes: AsyncSequence & Sendable, SequenceOfBytes.Element == ByteBuffer { + Self(storage: .anyAsyncSequence(AnyAsyncSequence(sequenceOfBytes.singleIteratorPrecondition))) + } - func next() async throws -> ByteBuffer? { - try await self.bag.nextResponsePart(streamID: self.id) - } + public static func bytes(_ byteBuffer: ByteBuffer) -> Self { + .stream(CollectionOfOne(byteBuffer).async) } } -#endif +@available(*, unavailable) +extension HTTPClientResponse.Body.AsyncIterator: Sendable {} + +@available(*, unavailable) +extension HTTPClientResponse.Body.Storage.AsyncIterator: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/SingleIteratorPrecondition.swift b/Sources/AsyncHTTPClient/AsyncAwait/SingleIteratorPrecondition.swift new file mode 100644 index 000000000..04034db2d --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/SingleIteratorPrecondition.swift @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Atomics + +/// Makes sure that a consumer of this `AsyncSequence` only calls `makeAsyncIterator()` at most once. +/// If `makeAsyncIterator()` is called multiple times, the program crashes. +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +@usableFromInline struct SingleIteratorPrecondition: AsyncSequence { + @usableFromInline let base: Base + @usableFromInline let didCreateIterator: ManagedAtomic = .init(false) + @usableFromInline typealias Element = Base.Element + @inlinable init(base: Base) { + self.base = base + } + + @inlinable func makeAsyncIterator() -> Base.AsyncIterator { + precondition( + self.didCreateIterator.exchange(true, ordering: .relaxed) == false, + "makeAsyncIterator() is only allowed to be called at most once." + ) + return self.base.makeAsyncIterator() + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension SingleIteratorPrecondition: @unchecked Sendable where Base: Sendable {} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension AsyncSequence { + @inlinable var singleIteratorPrecondition: SingleIteratorPrecondition { + .init(base: self) + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift index dea1093db..457627a8a 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift @@ -11,13 +11,14 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) + import Logging import NIOCore import NIOHTTP1 @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction { + @usableFromInline struct StateMachine { struct ExecutionContext { let executor: HTTPRequestExecutor @@ -28,32 +29,24 @@ extension Transaction { private enum State { case initialized(CheckedContinuation) case queued(CheckedContinuation, HTTPRequestScheduler) + case deadlineExceededWhileQueued(CheckedContinuation) case executing(ExecutionContext, RequestStreamState, ResponseStreamState) - case finished(error: Error?, HTTPClientResponse.Body.IteratorStream.ID?) + case finished(error: Error?) } - fileprivate enum RequestStreamState { + fileprivate enum RequestStreamState: Sendable { case requestHeadSent case producing case paused(continuation: CheckedContinuation?) case finished } - fileprivate enum ResponseStreamState { - enum Next { - case askExecutorForMore - case error(Error) - case endOfFile - } - - // Waiting for response head. Valid transitions to: waitingForStream. + fileprivate enum ResponseStreamState: Sendable { + // Waiting for response head. Valid transitions to: streamingBody. case waitingForResponseHead - // We are waiting for the user to create a response body iterator and to call next on - // it for the first time. - case waitingForResponseIterator(CircularBuffer, next: Next) - case buffering(HTTPClientResponse.Body.IteratorStream.ID, CircularBuffer, next: Next) - case waitingForRemote(HTTPClientResponse.Body.IteratorStream.ID, CheckedContinuation) - case finished(HTTPClientResponse.Body.IteratorStream.ID, CheckedContinuation) + // streaming response body. Valid transitions to: finished. + case streamingBody(TransactionBody.Source) + case finished } private var state: State @@ -89,9 +82,20 @@ extension Transaction { enum FailAction { case none /// fail response before head received. scheduler and executor are exclusive here. - case failResponseHead(CheckedContinuation, Error, HTTPRequestScheduler?, HTTPRequestExecutor?, bodyStreamContinuation: CheckedContinuation?) + case failResponseHead( + CheckedContinuation, + Error, + HTTPRequestScheduler?, + HTTPRequestExecutor?, + bodyStreamContinuation: CheckedContinuation? + ) /// fail response after response head received. fail the response stream (aka call to `next()`) - case failResponseStream(CheckedContinuation, Error, HTTPRequestExecutor, bodyStreamContinuation: CheckedContinuation?) + case failResponseStream( + TransactionBody.Source, + Error, + HTTPRequestExecutor, + bodyStreamContinuation: CheckedContinuation? + ) case failRequestStreamContinuation(CheckedContinuation, Error) } @@ -99,78 +103,66 @@ extension Transaction { mutating func fail(_ error: Error) -> FailAction { switch self.state { case .initialized(let continuation): - self.state = .finished(error: error, nil) + self.state = .finished(error: error) return .failResponseHead(continuation, error, nil, nil, bodyStreamContinuation: nil) case .queued(let continuation, let scheduler): - self.state = .finished(error: error, nil) + self.state = .finished(error: error) return .failResponseHead(continuation, error, scheduler, nil, bodyStreamContinuation: nil) - + case .deadlineExceededWhileQueued(let continuation): + let realError: Error = { + if (error as? HTTPClientError) == .cancelled { + /// if we just get a `HTTPClientError.cancelled` we can use the original cancellation reason + /// to give a more descriptive error to the user. + return HTTPClientError.deadlineExceeded + } else { + /// otherwise we already had an intermediate connection error which we should present to the user instead + return error + } + }() + + self.state = .finished(error: realError) + return .failResponseHead(continuation, realError, nil, nil, bodyStreamContinuation: nil) case .executing(let context, let requestStreamState, .waitingForResponseHead): switch requestStreamState { case .paused(continuation: .some(let continuation)): - self.state = .finished(error: error, nil) - return .failResponseHead(context.continuation, error, nil, context.executor, bodyStreamContinuation: continuation) + self.state = .finished(error: error) + return .failResponseHead( + context.continuation, + error, + nil, + context.executor, + bodyStreamContinuation: continuation + ) case .requestHeadSent, .finished, .producing, .paused(continuation: .none): - self.state = .finished(error: error, nil) - return .failResponseHead(context.continuation, error, nil, context.executor, bodyStreamContinuation: nil) - } - - case .executing(let context, let requestStreamState, .waitingForResponseIterator(let buffer, next: .askExecutorForMore)), - .executing(let context, let requestStreamState, .waitingForResponseIterator(let buffer, next: .endOfFile)): - switch requestStreamState { - case .paused(.some(let continuation)): - self.state = .executing(context, .finished, .waitingForResponseIterator(buffer, next: .error(error))) - return .failRequestStreamContinuation(continuation, error) - - case .requestHeadSent, .producing, .paused(continuation: .none), .finished: - self.state = .executing(context, .finished, .waitingForResponseIterator(buffer, next: .error(error))) - return .none - } - - case .executing(let context, let requestStreamState, .buffering(let streamID, let buffer, next: .askExecutorForMore)), - .executing(let context, let requestStreamState, .buffering(let streamID, let buffer, next: .endOfFile)): - switch requestStreamState { - case .paused(continuation: .some(let continuation)): - self.state = .executing(context, .finished, .buffering(streamID, buffer, next: .error(error))) - return .failRequestStreamContinuation(continuation, error) - - case .requestHeadSent, .paused(continuation: .none), .producing, .finished: - self.state = .executing(context, .finished, .buffering(streamID, buffer, next: .error(error))) - return .none + self.state = .finished(error: error) + return .failResponseHead( + context.continuation, + error, + nil, + context.executor, + bodyStreamContinuation: nil + ) } - case .executing(let context, let requestStreamState, .waitingForRemote(let streamID, let continuation)): - // We are in response streaming. The response stream is waiting for the next bytes - // from the server. We can fail the call to `next` immediately. + case .executing(let context, let requestStreamState, .streamingBody(let source)): + self.state = .finished(error: error) switch requestStreamState { - case .paused(continuation: .some(let bodyStreamContinuation)): - self.state = .finished(error: error, streamID) - return .failResponseStream(continuation, error, context.executor, bodyStreamContinuation: bodyStreamContinuation) - - case .requestHeadSent, .paused(continuation: .none), .producing, .finished: - self.state = .finished(error: error, streamID) - return .failResponseStream(continuation, error, context.executor, bodyStreamContinuation: nil) + case .paused(let bodyStreamContinuation): + return .failResponseStream( + source, + error, + context.executor, + bodyStreamContinuation: bodyStreamContinuation + ) + case .finished, .producing, .requestHeadSent: + return .failResponseStream(source, error, context.executor, bodyStreamContinuation: nil) } - case .finished(error: _, _), - .executing(_, _, .waitingForResponseIterator(_, next: .error)), - .executing(_, _, .buffering(_, _, next: .error)): - // The request has already failed, succeeded, or the users is not interested in the - // response. There is no more way to reach the user code. Just drop the error. + case .finished(error: _), + .executing(_, _, .finished): return .none - - case .executing(let context, let requestStreamState, .finished(let streamID, let continuation)): - switch requestStreamState { - case .paused(continuation: .some(let bodyStreamContinuation)): - self.state = .finished(error: error, streamID) - return .failResponseStream(continuation, error, context.executor, bodyStreamContinuation: bodyStreamContinuation) - - case .requestHeadSent, .paused(continuation: .none), .producing, .finished: - self.state = .finished(error: error, streamID) - return .failResponseStream(continuation, error, context.executor, bodyStreamContinuation: nil) - } } } @@ -178,6 +170,7 @@ extension Transaction { enum StartExecutionAction { case cancel(HTTPRequestExecutor) + case cancelAndFail(HTTPRequestExecutor, CheckedContinuation, with: Error) case none } @@ -191,13 +184,16 @@ extension Transaction { ) self.state = .executing(context, .requestHeadSent, .waitingForResponseHead) return .none + case .deadlineExceededWhileQueued(let continuation): + let error = HTTPClientError.deadlineExceeded + self.state = .finished(error: error) + return .cancelAndFail(executor, continuation, with: error) - case .finished(error: .some, .none): + case .finished(error: .some): return .cancel(executor) case .executing, - .finished(error: .none, _), - .finished(error: .some, .some): + .finished(error: .none): preconditionFailure("Invalid state: \(self.state)") } } @@ -210,8 +206,10 @@ extension Transaction { mutating func resumeRequestBodyStream() -> ResumeProducingAction { switch self.state { - case .initialized, .queued: - preconditionFailure("Received a resumeBodyRequest on a request, that isn't executing. Invalid state: \(self.state)") + case .initialized, .queued, .deadlineExceededWhileQueued: + preconditionFailure( + "Received a resumeBodyRequest on a request, that isn't executing. Invalid state: \(self.state)" + ) case .executing(let context, .requestHeadSent, let responseState): // the request can start to send its body. @@ -219,7 +217,9 @@ extension Transaction { return .startStream(context.allocator) case .executing(_, .producing, _): - preconditionFailure("Received a resumeBodyRequest on a request, that is producing. Invalid state: \(self.state)") + preconditionFailure( + "Received a resumeBodyRequest on a request, that is producing. Invalid state: \(self.state)" + ) case .executing(let context, .paused(.none), let responseState): // request stream is currently paused, but there is no write waiting. We don't need @@ -245,16 +245,17 @@ extension Transaction { mutating func pauseRequestBodyStream() { switch self.state { case .initialized, - .queued, - .executing(_, .requestHeadSent, _): + .queued, + .deadlineExceededWhileQueued, + .executing(_, .requestHeadSent, _): preconditionFailure("A request stream can only be resumed, if the request was started") case .executing(let context, .producing, let responseSteam): self.state = .executing(context, .paused(continuation: nil), responseSteam) case .executing(_, .paused, _), - .executing(_, .finished, _), - .finished: + .executing(_, .finished, _), + .finished: // the channels writability changed to paused after we have already forwarded all // request bytes. Can be ignored. break @@ -270,9 +271,12 @@ extension Transaction { func writeNextRequestPart() -> NextWriteAction { switch self.state { case .initialized, - .queued, - .executing(_, .requestHeadSent, _): - preconditionFailure("A request stream can only produce, if the request was started. Invalid state: \(self.state)") + .queued, + .deadlineExceededWhileQueued, + .executing(_, .requestHeadSent, _): + preconditionFailure( + "A request stream can only produce, if the request was started. Invalid state: \(self.state)" + ) case .executing(let context, .producing, _): // We are currently producing the request body. The executors channel is writable. @@ -290,7 +294,9 @@ extension Transaction { return .writeAndWait(context.executor) case .executing(_, .paused(continuation: .some), _): - preconditionFailure("A write continuation already exists, but we tried to set another one. Invalid state: \(self.state)") + preconditionFailure( + "A write continuation already exists, but we tried to set another one. Invalid state: \(self.state)" + ) case .finished, .executing(_, .finished, _): return .fail @@ -300,10 +306,13 @@ extension Transaction { mutating func waitForRequestBodyDemand(continuation: CheckedContinuation) { switch self.state { case .initialized, - .queued, - .executing(_, .requestHeadSent, _), - .executing(_, .finished, _): - preconditionFailure("A request stream can only produce, if the request was started. Invalid state: \(self.state)") + .queued, + .deadlineExceededWhileQueued, + .executing(_, .requestHeadSent, _), + .executing(_, .finished, _): + preconditionFailure( + "A request stream can only produce, if the request was started. Invalid state: \(self.state)" + ) case .executing(_, .producing, _): preconditionFailure() @@ -324,36 +333,38 @@ extension Transaction { } enum FinishAction { - // forward the notice that the request stream has finished. If finalContinuation is not - // nil, succeed the continuation with nil to signal the requests end. - case forwardStreamFinished(HTTPRequestExecutor, finalContinuation: CheckedContinuation?) + // forward the notice that the request stream has finished. + case forwardStreamFinished(HTTPRequestExecutor) case none } mutating func finishRequestBodyStream() -> FinishAction { switch self.state { case .initialized, - .queued, - .executing(_, .finished, _): + .queued, + .deadlineExceededWhileQueued, + .executing(_, .finished, _): preconditionFailure("Invalid state: \(self.state)") case .executing(_, .paused(continuation: .some), _): - preconditionFailure("Received a request body end, while having a registered back-pressure continuation. Invalid state: \(self.state)") + preconditionFailure( + "Received a request body end, while having a registered back-pressure continuation. Invalid state: \(self.state)" + ) case .executing(let context, .producing, let responseState), - .executing(let context, .paused(continuation: .none), let responseState), - .executing(let context, .requestHeadSent, let responseState): + .executing(let context, .paused(continuation: .none), let responseState), + .executing(let context, .requestHeadSent, let responseState): switch responseState { - case .finished(let registeredStreamID, let continuation): + case .finished: // if the response stream has already finished before the request, we must succeed // the final continuation. - self.state = .finished(error: nil, registeredStreamID) - return .forwardStreamFinished(context.executor, finalContinuation: continuation) + self.state = .finished(error: nil) + return .forwardStreamFinished(context.executor) - case .waitingForResponseHead, .waitingForResponseIterator, .waitingForRemote, .buffering: + case .waitingForResponseHead, .streamingBody: self.state = .executing(context, .finished, responseState) - return .forwardStreamFinished(context.executor, finalContinuation: nil) + return .forwardStreamFinished(context.executor) } case .finished: @@ -364,319 +375,125 @@ extension Transaction { // MARK: - Response - enum ReceiveResponseHeadAction { - case succeedResponseHead(HTTPResponseHead, CheckedContinuation) + case succeedResponseHead(TransactionBody, CheckedContinuation) case none } - mutating func receiveResponseHead(_ head: HTTPResponseHead) -> ReceiveResponseHeadAction { + mutating func receiveResponseHead( + _ head: HTTPResponseHead, + delegate: Delegate + ) -> ReceiveResponseHeadAction { switch self.state { case .initialized, - .queued, - .executing(_, _, .waitingForResponseIterator), - .executing(_, _, .buffering), - .executing(_, _, .waitingForRemote): - preconditionFailure("How can we receive a response, if the request hasn't started yet.") + .queued, + .deadlineExceededWhileQueued, + .executing(_, _, .streamingBody), + .executing(_, _, .finished): + preconditionFailure("invalid state \(self.state)") case .executing(let context, let requestState, .waitingForResponseHead): // The response head was received. Next we will wait for the consumer to create a // response body stream. - self.state = .executing(context, requestState, .waitingForResponseIterator(.init(), next: .askExecutorForMore)) - return .succeedResponseHead(head, context.continuation) + let body = TransactionBody.makeSequence( + backPressureStrategy: .init(lowWatermark: 1, highWatermark: 1), + finishOnDeinit: true, + delegate: AnyAsyncSequenceProducerDelegate(delegate) + ) - case .finished(error: .some, _): + self.state = .executing(context, requestState, .streamingBody(body.source)) + return .succeedResponseHead(body.sequence, context.continuation) + + case .finished(error: .some): // If the request failed before, we don't need to do anything in response to // receiving the response head. return .none - case .executing(_, _, .finished), - .finished(error: .none, _): + case .finished(error: .none): preconditionFailure("How can the request be finished without error, before receiving response head?") } } - enum ReceiveResponsePartAction { + enum ProduceMoreAction { case none - case succeedContinuation(CheckedContinuation, ByteBuffer) + case requestMoreResponseBodyParts(HTTPRequestExecutor) } - mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ReceiveResponsePartAction { + mutating func produceMore() -> ProduceMoreAction { switch self.state { - case .initialized, .queued: - preconditionFailure("Received a response body part, but request hasn't started yet. Invalid state: \(self.state)") - - case .executing(_, _, .waitingForResponseHead): - preconditionFailure("If we receive a response body, we must have received a head before") - - case .executing(let context, let requestState, .buffering(let streamID, var currentBuffer, next: let next)): - guard case .askExecutorForMore = next else { - preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") - } - - if currentBuffer.isEmpty { - currentBuffer = buffer - } else { - currentBuffer.append(contentsOf: buffer) - } - self.state = .executing(context, requestState, .buffering(streamID, currentBuffer, next: next)) - return .none - - case .executing(let executor, let requestState, .waitingForResponseIterator(var currentBuffer, next: let next)): - guard case .askExecutorForMore = next else { - preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") - } - - if currentBuffer.isEmpty { - currentBuffer = buffer - } else { - currentBuffer.append(contentsOf: buffer) - } - self.state = .executing(executor, requestState, .waitingForResponseIterator(currentBuffer, next: next)) - return .none - - case .executing(let executor, let requestState, .waitingForRemote(let streamID, let continuation)): - var buffer = buffer - let first = buffer.removeFirst() - self.state = .executing(executor, requestState, .buffering(streamID, buffer, next: .askExecutorForMore)) - return .succeedContinuation(continuation, first) - - case .finished: - // the request failed or was cancelled before, we can ignore further data + case .initialized, + .queued, + .deadlineExceededWhileQueued, + .executing(_, _, .waitingForResponseHead): + preconditionFailure("invalid state \(self.state)") + + case .executing(let context, _, .streamingBody): + return .requestMoreResponseBodyParts(context.executor) + case .finished, + .executing(_, _, .finished): return .none - - case .executing(_, _, .finished): - preconditionFailure("Received response end. Must not receive further body parts after that. Invalid state: \(self.state)") } } - enum ResponseBodyDeinitedAction { - case cancel(HTTPRequestExecutor) + enum ReceiveResponsePartAction { case none + case yieldResponseBodyParts(TransactionBody.Source, CircularBuffer, HTTPRequestExecutor) } - mutating func responseBodyDeinited() -> ResponseBodyDeinitedAction { + mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ReceiveResponsePartAction { switch self.state { - case .initialized, - .queued, - .executing(_, _, .waitingForResponseHead): - preconditionFailure("Got notice about a deinited response, before we even received a response. Invalid state: \(self.state)") + case .initialized, .queued, .deadlineExceededWhileQueued: + preconditionFailure( + "Received a response body part, but request hasn't started yet. Invalid state: \(self.state)" + ) - case .executing(_, _, .waitingForResponseIterator(_, next: .endOfFile)): - self.state = .finished(error: nil, nil) - return .none + case .executing(_, _, .waitingForResponseHead): + preconditionFailure("If we receive a response body, we must have received a head before") - case .executing(let context, _, .waitingForResponseIterator(_, next: .askExecutorForMore)): - self.state = .finished(error: nil, nil) - return .cancel(context.executor) - - case .executing(_, _, .waitingForResponseIterator(_, next: .error(let error))): - self.state = .finished(error: error, nil) - return .none + case .executing(let context, _, .streamingBody(let source)): + return .yieldResponseBodyParts(source, buffer, context.executor) case .finished: - // body was released after the response was consumed - return .none - - case .executing(_, _, .buffering), - .executing(_, _, .waitingForRemote), - .executing(_, _, .finished): - // user is consuming the stream with an iterator - return .none - } - } - - mutating func responseBodyIteratorDeinited(streamID: HTTPClientResponse.Body.IteratorStream.ID) -> FailAction { - switch self.state { - case .initialized, .queued, .executing(_, _, .waitingForResponseHead): - preconditionFailure("Got notice about a deinited response body iterator, before we even received a response. Invalid state: \(self.state)") - - case .executing(_, _, .buffering(let registeredStreamID, _, next: _)), - .executing(_, _, .waitingForRemote(let registeredStreamID, _)): - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - return self.fail(HTTPClientError.cancelled) - - case .executing(_, _, .waitingForResponseIterator), - .executing(_, _, .finished), - .finished: - // the iterator went out of memory after the request was done. nothing to do. + // the request failed or was cancelled before, we can ignore further data return .none - } - } - - enum ConsumeAction { - case succeedContinuation(CheckedContinuation, ByteBuffer?) - case failContinuation(CheckedContinuation, Error) - case askExecutorForMore(HTTPRequestExecutor) - case none - } - - mutating func consumeNextResponsePart( - streamID: HTTPClientResponse.Body.IteratorStream.ID, - continuation: CheckedContinuation - ) -> ConsumeAction { - switch self.state { - case .initialized, - .queued, - .executing(_, _, .waitingForResponseHead): - preconditionFailure("If we receive a response body, we must have received a head before") case .executing(_, _, .finished): - preconditionFailure("This is an invalid state at this point. We are waiting for the request stream to finish to succeed the response stream. By sending a fi") - - case .executing(let context, let requestState, .waitingForResponseIterator(var buffer, next: .askExecutorForMore)): - if buffer.isEmpty { - self.state = .executing(context, requestState, .waitingForRemote(streamID, continuation)) - return .askExecutorForMore(context.executor) - } else { - let toReturn = buffer.removeFirst() - self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .askExecutorForMore)) - return .succeedContinuation(continuation, toReturn) - } - - case .executing(_, _, .waitingForResponseIterator(_, next: .error(let error))): - self.state = .finished(error: error, streamID) - return .failContinuation(continuation, error) - - case .executing(_, _, .waitingForResponseIterator(let buffer, next: .endOfFile)) where buffer.isEmpty: - self.state = .finished(error: nil, streamID) - return .succeedContinuation(continuation, nil) - - case .executing(let context, let requestState, .waitingForResponseIterator(var buffer, next: .endOfFile)): - assert(!buffer.isEmpty) - let toReturn = buffer.removeFirst() - self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .endOfFile)) - return .succeedContinuation(continuation, toReturn) - - case .executing(let context, let requestState, .buffering(let registeredStreamID, var buffer, next: .askExecutorForMore)): - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - - if buffer.isEmpty { - self.state = .executing(context, requestState, .waitingForRemote(streamID, continuation)) - return .askExecutorForMore(context.executor) - } else { - let toReturn = buffer.removeFirst() - self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .askExecutorForMore)) - return .succeedContinuation(continuation, toReturn) - } - - case .executing(_, _, .buffering(let registeredStreamID, _, next: .error(let error))): - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - self.state = .finished(error: error, registeredStreamID) - return .failContinuation(continuation, error) - - case .executing(_, _, .buffering(let registeredStreamID, let buffer, next: .endOfFile)) where buffer.isEmpty: - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - self.state = .finished(error: nil, registeredStreamID) - return .succeedContinuation(continuation, nil) - - case .executing(let context, let requestState, .buffering(let registeredStreamID, var buffer, next: .endOfFile)): - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - if let toReturn = buffer.popFirst() { - // As long as we have bytes in the local store, we can hand them to the user. - self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .endOfFile)) - return .succeedContinuation(continuation, toReturn) - } - - switch requestState { - case .requestHeadSent, .paused, .producing: - // if the request isn't finished yet, we don't succeed the final response stream - // continuation. We will succeed it once the request has been fully send. - self.state = .executing(context, requestState, .finished(streamID, continuation)) - return .none - case .finished: - // if the request is finished, we can succeed the final continuation. - self.state = .finished(error: nil, streamID) - return .succeedContinuation(continuation, nil) - } - - case .executing(_, _, .waitingForRemote(let registeredStreamID, _)): - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - preconditionFailure("A body response continuation from this iterator already exists! Queuing calls to `next()` is not supported.") - - case .finished(error: .some(let error), let registeredStreamID): - if let registeredStreamID = registeredStreamID { - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - } else { - self.state = .finished(error: error, streamID) - } - return .failContinuation(continuation, error) - - case .finished(error: .none, let registeredStreamID): - if let registeredStreamID = registeredStreamID { - self.verifyStreamIDIsEqual(registered: registeredStreamID, this: streamID) - } else { - self.state = .finished(error: .none, streamID) - } - - return .succeedContinuation(continuation, nil) - } - } - - private func verifyStreamIDIsEqual( - registered: HTTPClientResponse.Body.IteratorStream.ID, - this: HTTPClientResponse.Body.IteratorStream.ID, - file: StaticString = #file, - line: UInt = #line - ) { - if registered != this { preconditionFailure( - "Tried to use a second iterator on response body stream. Multiple iterators are not supported.", - file: file, line: line + "Received response end. Must not receive further body parts after that. Invalid state: \(self.state)" ) } } enum ReceiveResponseEndAction { - case succeedContinuation(CheckedContinuation, ByteBuffer) - case finishResponseStream(CheckedContinuation) + case finishResponseStream(TransactionBody.Source, finalBody: CircularBuffer?) case none } mutating func succeedRequest(_ newChunks: CircularBuffer?) -> ReceiveResponseEndAction { switch self.state { case .initialized, - .queued, - .executing(_, _, .waitingForResponseHead): - preconditionFailure("Received no response head, but received a response end. Invalid state: \(self.state)") - - case .executing(let context, let requestState, .waitingForResponseIterator(var buffer, next: .askExecutorForMore)): - if let newChunks = newChunks, !newChunks.isEmpty { - buffer.append(contentsOf: newChunks) - } - self.state = .executing(context, requestState, .waitingForResponseIterator(buffer, next: .endOfFile)) - return .none - - case .executing(let context, let requestState, .waitingForRemote(let streamID, let continuation)): - if var newChunks = newChunks, !newChunks.isEmpty { - let first = newChunks.removeFirst() - self.state = .executing(context, requestState, .buffering(streamID, newChunks, next: .endOfFile)) - return .succeedContinuation(continuation, first) - } - - self.state = .finished(error: nil, streamID) - return .finishResponseStream(continuation) - - case .executing(let context, let requestState, .buffering(let streamID, var buffer, next: .askExecutorForMore)): - if let newChunks = newChunks, !newChunks.isEmpty { - buffer.append(contentsOf: newChunks) - } - self.state = .executing(context, requestState, .buffering(streamID, buffer, next: .endOfFile)) - return .none + .queued, + .deadlineExceededWhileQueued, + .executing(_, _, .waitingForResponseHead): + preconditionFailure( + "Received no response head, but received a response end. Invalid state: \(self.state)" + ) + case .executing(let context, let requestState, .streamingBody(let source)): + self.state = .executing(context, requestState, .finished) + return .finishResponseStream(source, finalBody: newChunks) case .finished: // the request failed or was cancelled before, we can ignore all events return .none - - case .executing(_, _, .waitingForResponseIterator(_, next: .error)), - .executing(_, _, .waitingForResponseIterator(_, next: .endOfFile)), - .executing(_, _, .buffering(_, _, next: .error)), - .executing(_, _, .buffering(_, _, next: .endOfFile)), - .executing(_, _, .finished(_, _)): - preconditionFailure("Already received an eof or error before. Must not receive further events. Invalid state: \(self.state)") + case .executing(_, _, .finished): + preconditionFailure( + "Already received an eof or error before. Must not receive further events. Invalid state: \(self.state)" + ) } } enum DeadlineExceededAction { case none + case cancelSchedulerOnly(scheduler: HTTPRequestScheduler) /// fail response before head received. scheduler and executor are exclusive here. case cancel( requestContinuation: CheckedContinuation, @@ -690,7 +507,7 @@ extension Transaction { let error = HTTPClientError.deadlineExceeded switch self.state { case .initialized(let continuation): - self.state = .finished(error: error, nil) + self.state = .finished(error: error) return .cancel( requestContinuation: continuation, scheduler: nil, @@ -699,18 +516,16 @@ extension Transaction { ) case .queued(let continuation, let scheduler): - self.state = .finished(error: error, nil) - return .cancel( - requestContinuation: continuation, - scheduler: scheduler, - executor: nil, - bodyStreamContinuation: nil + self.state = .deadlineExceededWhileQueued(continuation) + return .cancelSchedulerOnly( + scheduler: scheduler ) - + case .deadlineExceededWhileQueued: + return .none case .executing(let context, let requestStreamState, .waitingForResponseHead): switch requestStreamState { case .paused(continuation: .some(let continuation)): - self.state = .finished(error: error, nil) + self.state = .finished(error: error) return .cancel( requestContinuation: context.continuation, scheduler: nil, @@ -718,7 +533,7 @@ extension Transaction { bodyStreamContinuation: continuation ) case .requestHeadSent, .finished, .producing, .paused(continuation: .none): - self.state = .finished(error: error, nil) + self.state = .finished(error: error) return .cancel( requestContinuation: context.continuation, scheduler: nil, @@ -738,4 +553,5 @@ extension Transaction { } } -#endif +@available(*, unavailable) +extension Transaction.StateMachine: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift index c2ce52eeb..a25c92e80 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift @@ -12,15 +12,19 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) import Logging import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import NIOSSL +import Tracing @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -final class Transaction: @unchecked Sendable { +@usableFromInline +final class Transaction: + // until NIOLockedValueBox learns `sending` because StateMachine cannot be Sendable + @unchecked Sendable +{ let logger: Logger let request: HTTPClientRequest.Prepared @@ -29,8 +33,7 @@ final class Transaction: @unchecked Sendable { let preferredEventLoop: EventLoop let requestOptions: RequestOptions - private let stateLock = Lock() - private var state: StateMachine + private let state: NIOLockedValueBox init( request: HTTPClientRequest.Prepared, @@ -45,11 +48,11 @@ final class Transaction: @unchecked Sendable { self.logger = logger self.connectionDeadline = connectionDeadline self.preferredEventLoop = preferredEventLoop - self.state = StateMachine(responseContinuation) + self.state = NIOLockedValueBox(StateMachine(responseContinuation)) } func cancel() { - self.fail(HTTPClientError.cancelled) + self.fail(CancellationError()) } // MARK: Request body helpers @@ -57,13 +60,13 @@ final class Transaction: @unchecked Sendable { private func writeOnceAndOneTimeOnly(byteBuffer: ByteBuffer) { // This method is synchronously invoked after sending the request head. For this reason we // can make a number of assumptions, how the state machine will react. - let writeAction = self.stateLock.withLock { - self.state.writeNextRequestPart() + let writeAction = self.state.withLockedValue { state in + state.writeNextRequestPart() } switch writeAction { case .writeAndWait(let executor), .writeAndContinue(let executor): - executor.writeRequestBodyPart(.byteBuffer(byteBuffer), request: self) + executor.writeRequestBodyPart(.byteBuffer(byteBuffer), request: self, promise: nil) case .fail: // an error/cancellation has happened. we don't need to continue here @@ -75,9 +78,11 @@ final class Transaction: @unchecked Sendable { private func continueRequestBodyStream( _ allocator: ByteBufferAllocator, - next: @escaping ((ByteBufferAllocator) async throws -> ByteBuffer?) + makeAsyncIterator: @Sendable @escaping () -> ((ByteBufferAllocator) async throws -> ByteBuffer?) ) { Task { + let next = makeAsyncIterator() + do { while let part = try await next(allocator) { do { @@ -101,29 +106,49 @@ final class Transaction: @unchecked Sendable { struct BreakTheWriteLoopError: Swift.Error {} private func writeRequestBodyPart(_ part: ByteBuffer) async throws { - self.stateLock.lock() - switch self.state.writeNextRequestPart() { - case .writeAndContinue(let executor): - self.stateLock.unlock() - executor.writeRequestBodyPart(.byteBuffer(part), request: self) + let action = self.state.withLockedValue { state in + state.writeNextRequestPart() + } - case .writeAndWait(let executor): + switch action { + case .writeAndContinue(let executor): + executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil) + case .writeAndWait: + // Holding the lock here *should* be safe but because of a bug in the runtime + // it isn't, so drop the lock, create the continuation and try again. + // + // See https://github.com/swiftlang/swift/issues/85668 try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - self.state.waitForRequestBodyDemand(continuation: continuation) - self.stateLock.unlock() + let action = self.state.withLockedValue { state in + // Check that nothing has changed between dropping and re-acquiring the lock. + let action = state.writeNextRequestPart() + switch action { + case .writeAndContinue, .fail: + () + case .writeAndWait: + state.waitForRequestBodyDemand(continuation: continuation) + } + return action + } - executor.writeRequestBodyPart(.byteBuffer(part), request: self) + switch action { + case .writeAndContinue(let executor): + executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil) + continuation.resume() + case .writeAndWait(let executor): + executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil) + case .fail: + continuation.resume(throwing: BreakTheWriteLoopError()) + } } - case .fail: - self.stateLock.unlock() throw BreakTheWriteLoopError() } } private func requestBodyStreamFinished() { - let finishAction = self.stateLock.withLock { - self.state.finishRequestBodyStream() + let finishAction = self.state.withLockedValue { state in + state.finishRequestBodyStream() } switch finishAction { @@ -131,9 +156,8 @@ final class Transaction: @unchecked Sendable { // an error/cancellation has happened. nothing to do. break - case .forwardStreamFinished(let executor, let succeedContinuation): - executor.finishRequestBodyStream(self) - succeedContinuation?.resume(returning: nil) + case .forwardStreamFinished(let executor): + executor.finishRequestBodyStream(self, promise: nil) } return } @@ -148,12 +172,12 @@ final class Transaction: @unchecked Sendable { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction: HTTPSchedulableRequest { var poolKey: ConnectionPool.Key { self.request.poolKey } - var tlsConfiguration: TLSConfiguration? { return nil } - var requiredEventLoop: EventLoop? { return nil } + var tlsConfiguration: TLSConfiguration? { self.request.tlsConfiguration } + var requiredEventLoop: EventLoop? { nil } func requestWasQueued(_ scheduler: HTTPRequestScheduler) { - self.stateLock.withLock { - self.state.requestWasQueued(scheduler) + self.state.withLockedValue { state in + state.requestWasQueued(scheduler) } } } @@ -167,14 +191,16 @@ extension Transaction: HTTPExecutableRequest { // MARK: Request func willExecuteRequest(_ executor: HTTPRequestExecutor) { - let action = self.stateLock.withLock { - self.state.willExecuteRequest(executor) + let action = self.state.withLockedValue { state in + state.willExecuteRequest(executor) } switch action { case .cancel(let executor): executor.cancelRequest(self) - + case .cancelAndFail(let executor, let continuation, with: let error): + executor.cancelRequest(self) + continuation.resume(throwing: error) case .none: break } @@ -183,8 +209,8 @@ extension Transaction: HTTPExecutableRequest { func requestHeadSent() {} func resumeRequestBodyStream() { - let action = self.stateLock.withLock { - self.state.resumeRequestBodyStream() + let action = self.state.withLockedValue { state in + state.resumeRequestBodyStream() } switch action { @@ -192,10 +218,10 @@ extension Transaction: HTTPExecutableRequest { break case .startStream(let allocator): - switch self.request.body?.mode { - case .asyncSequence(_, let next): + switch self.request.body { + case .asyncSequence(_, let makeAsyncIterator): // it is safe to call this async here. it dispatches... - self.continueRequestBodyStream(allocator, next: next) + self.continueRequestBodyStream(allocator, makeAsyncIterator: makeAsyncIterator) case .byteBuffer(let byteBuffer): self.writeOnceAndOneTimeOnly(byteBuffer: byteBuffer) @@ -214,62 +240,71 @@ extension Transaction: HTTPExecutableRequest { } func pauseRequestBodyStream() { - self.stateLock.withLock { - self.state.pauseRequestBodyStream() + self.state.withLockedValue { state in + state.pauseRequestBodyStream() } } // MARK: Response func receiveResponseHead(_ head: HTTPResponseHead) { - let action = self.stateLock.withLock { - self.state.receiveResponseHead(head) + let action = self.state.withLockedValue { state in + state.receiveResponseHead(head, delegate: self) } switch action { case .none: break - case .succeedResponseHead(let head, let continuation): - let asyncResponse = HTTPClientResponse( - bag: self, + case .succeedResponseHead(let body, let continuation): + let response = HTTPClientResponse( + requestMethod: self.requestHead.method, version: head.version, status: head.status, - headers: head.headers + headers: head.headers, + body: body, + history: [] ) - continuation.resume(returning: asyncResponse) + continuation.resume(returning: response) } } func receiveResponseBodyParts(_ buffer: CircularBuffer) { - let action = self.stateLock.withLock { - self.state.receiveResponseBodyParts(buffer) + let action = self.state.withLockedValue { state in + state.receiveResponseBodyParts(buffer) } switch action { case .none: break - case .succeedContinuation(let continuation, let bytes): - continuation.resume(returning: bytes) + case .yieldResponseBodyParts(let source, let responseBodyParts, let executer): + switch source.yield(contentsOf: responseBodyParts) { + case .dropped, .stopProducing: + break + case .produceMore: + executer.demandResponseBodyStream(self) + } } } func succeedRequest(_ buffer: CircularBuffer?) { - let succeedAction = self.stateLock.withLock { - self.state.succeedRequest(buffer) + let succeedAction = self.state.withLockedValue { state in + state.succeedRequest(buffer) } switch succeedAction { - case .finishResponseStream(let continuation): - continuation.resume(returning: nil) - case .succeedContinuation(let continuation, let byteBuffer): - continuation.resume(returning: byteBuffer) + case .finishResponseStream(let source, let finalResponse): + if let finalResponse = finalResponse { + _ = source.yield(contentsOf: finalResponse) + } + source.finish() + case .none: break } } func fail(_ error: Error) { - let action = self.stateLock.withLock { - self.state.fail(error) + let action = self.state.withLockedValue { state in + state.fail(error) } self.performFailAction(action) } @@ -282,12 +317,12 @@ extension Transaction: HTTPExecutableRequest { case .failResponseHead(let continuation, let error, let scheduler, let executor, let bodyStreamContinuation): continuation.resume(throwing: error) bodyStreamContinuation?.resume(throwing: error) - scheduler?.cancelRequest(self) // NOTE: scheduler and executor are exclusive here + scheduler?.cancelRequest(self) // NOTE: scheduler and executor are exclusive here executor?.cancelRequest(self) - case .failResponseStream(let continuation, let error, let executor, let bodyStreamContinuation): - continuation.resume(throwing: error) - bodyStreamContinuation?.resume(throwing: error) + case .failResponseStream(let source, let error, let executor, let requestBodyStreamContinuation): + source.finish(error) + requestBodyStreamContinuation?.resume(throwing: error) executor.cancelRequest(self) case .failRequestStreamContinuation(let bodyStreamContinuation, let error): @@ -296,8 +331,8 @@ extension Transaction: HTTPExecutableRequest { } func deadlineExceeded() { - let action = self.stateLock.withLock { - self.state.deadlineExceeded() + let action = self.state.withLockedValue { state in + state.deadlineExceeded() } self.performDeadlineExceededAction(action) } @@ -309,7 +344,8 @@ extension Transaction: HTTPExecutableRequest { scheduler?.cancelRequest(self) executor?.cancelRequest(self) bodyStreamContinuation?.resume(throwing: HTTPClientError.deadlineExceeded) - + case .cancelSchedulerOnly(let scheduler): + scheduler.cancelRequest(self) case .none: break } @@ -317,46 +353,22 @@ extension Transaction: HTTPExecutableRequest { } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension Transaction { - func responseBodyDeinited() { - let deinitedAction = self.stateLock.withLock { - self.state.responseBodyDeinited() +extension Transaction: NIOAsyncSequenceProducerDelegate { + @usableFromInline + func produceMore() { + let action = self.state.withLockedValue { state in + state.produceMore() } - - switch deinitedAction { - case .cancel(let executor): - executor.cancelRequest(self) + switch action { case .none: break + case .requestMoreResponseBodyParts(let executer): + executer.demandResponseBodyStream(self) } } - func nextResponsePart(streamID: HTTPClientResponse.Body.IteratorStream.ID) async throws -> ByteBuffer? { - try await withCheckedThrowingContinuation { continuation in - let action = self.stateLock.withLock { - self.state.consumeNextResponsePart(streamID: streamID, continuation: continuation) - } - switch action { - case .succeedContinuation(let continuation, let result): - continuation.resume(returning: result) - - case .failContinuation(let continuation, let error): - continuation.resume(throwing: error) - - case .askExecutorForMore(let executor): - executor.demandResponseBodyStream(self) - - case .none: - return - } - } - } - - func responseBodyIteratorDeinited(streamID: HTTPClientResponse.Body.IteratorStream.ID) { - let action = self.stateLock.withLock { - self.state.responseBodyIteratorDeinited(streamID: streamID) - } - self.performFailAction(action) + @usableFromInline + func didTerminate() { + self.fail(HTTPClientError.cancelled) } } -#endif diff --git a/Sources/AsyncHTTPClient/Base64.swift b/Sources/AsyncHTTPClient/Base64.swift index dbbf742ab..4d2ddcc49 100644 --- a/Sources/AsyncHTTPClient/Base64.swift +++ b/Sources/AsyncHTTPClient/Base64.swift @@ -19,156 +19,156 @@ extension String { - /// Base64 encode a collection of UInt8 to a string, without the use of Foundation. - @inlinable - init(base64Encoding bytes: Buffer) - where Buffer.Element == UInt8 - { - self = Base64.encode(bytes: bytes) - } + /// Base64 encode a collection of UInt8 to a string, without the use of Foundation. + @inlinable + init(base64Encoding bytes: Buffer) + where Buffer.Element == UInt8 { + self = Base64.encode(bytes: bytes) + } } +// swift-format-ignore: DontRepeatTypeInStaticProperties @usableFromInline -internal struct Base64 { - - @inlinable - static func encode(bytes: Buffer) - -> String where Buffer.Element == UInt8 - { - guard !bytes.isEmpty else { - return "" - } - // In Base64, 3 bytes become 4 output characters, and we pad to the - // nearest multiple of four. - let base64StringLength = ((bytes.count + 2) / 3) * 4 - let alphabet = Base64.encodeBase64 - - return String(customUnsafeUninitializedCapacity: base64StringLength) { backingStorage in - var input = bytes.makeIterator() - var offset = 0 - while let firstByte = input.next() { - let secondByte = input.next() - let thirdByte = input.next() - - backingStorage[offset] = Base64.encode(alphabet: alphabet, firstByte: firstByte) - backingStorage[offset + 1] = Base64.encode(alphabet: alphabet, firstByte: firstByte, secondByte: secondByte) - backingStorage[offset + 2] = Base64.encode(alphabet: alphabet, secondByte: secondByte, thirdByte: thirdByte) - backingStorage[offset + 3] = Base64.encode(alphabet: alphabet, thirdByte: thirdByte) - offset += 4 - } - return offset +internal struct Base64: Sendable { + + @inlinable + static func encode( + bytes: Buffer + ) + -> String where Buffer.Element == UInt8 + { + guard !bytes.isEmpty else { + return "" + } + // In Base64, 3 bytes become 4 output characters, and we pad to the + // nearest multiple of four. + let base64StringLength = ((bytes.count + 2) / 3) * 4 + let alphabet = Base64.encodeBase64 + + return String(customUnsafeUninitializedCapacity: base64StringLength) { backingStorage in + var input = bytes.makeIterator() + var offset = 0 + while let firstByte = input.next() { + let secondByte = input.next() + let thirdByte = input.next() + + backingStorage[offset] = Base64.encode(alphabet: alphabet, firstByte: firstByte) + backingStorage[offset + 1] = Base64.encode( + alphabet: alphabet, + firstByte: firstByte, + secondByte: secondByte + ) + backingStorage[offset + 2] = Base64.encode( + alphabet: alphabet, + secondByte: secondByte, + thirdByte: thirdByte + ) + backingStorage[offset + 3] = Base64.encode(alphabet: alphabet, thirdByte: thirdByte) + offset += 4 + } + return offset + } } - } - - // MARK: Internal - - // The base64 unicode table. - @usableFromInline - static let encodeBase64: [UInt8] = [ - UInt8(ascii: "A"), UInt8(ascii: "B"), UInt8(ascii: "C"), UInt8(ascii: "D"), - UInt8(ascii: "E"), UInt8(ascii: "F"), UInt8(ascii: "G"), UInt8(ascii: "H"), - UInt8(ascii: "I"), UInt8(ascii: "J"), UInt8(ascii: "K"), UInt8(ascii: "L"), - UInt8(ascii: "M"), UInt8(ascii: "N"), UInt8(ascii: "O"), UInt8(ascii: "P"), - UInt8(ascii: "Q"), UInt8(ascii: "R"), UInt8(ascii: "S"), UInt8(ascii: "T"), - UInt8(ascii: "U"), UInt8(ascii: "V"), UInt8(ascii: "W"), UInt8(ascii: "X"), - UInt8(ascii: "Y"), UInt8(ascii: "Z"), UInt8(ascii: "a"), UInt8(ascii: "b"), - UInt8(ascii: "c"), UInt8(ascii: "d"), UInt8(ascii: "e"), UInt8(ascii: "f"), - UInt8(ascii: "g"), UInt8(ascii: "h"), UInt8(ascii: "i"), UInt8(ascii: "j"), - UInt8(ascii: "k"), UInt8(ascii: "l"), UInt8(ascii: "m"), UInt8(ascii: "n"), - UInt8(ascii: "o"), UInt8(ascii: "p"), UInt8(ascii: "q"), UInt8(ascii: "r"), - UInt8(ascii: "s"), UInt8(ascii: "t"), UInt8(ascii: "u"), UInt8(ascii: "v"), - UInt8(ascii: "w"), UInt8(ascii: "x"), UInt8(ascii: "y"), UInt8(ascii: "z"), - UInt8(ascii: "0"), UInt8(ascii: "1"), UInt8(ascii: "2"), UInt8(ascii: "3"), - UInt8(ascii: "4"), UInt8(ascii: "5"), UInt8(ascii: "6"), UInt8(ascii: "7"), - UInt8(ascii: "8"), UInt8(ascii: "9"), UInt8(ascii: "+"), UInt8(ascii: "/"), - ] - - static let encodePaddingCharacter: UInt8 = UInt8(ascii: "=") - - @usableFromInline - static func encode(alphabet: [UInt8], firstByte: UInt8) -> UInt8 { - let index = firstByte >> 2 - return alphabet[Int(index)] - } - - @usableFromInline - static func encode(alphabet: [UInt8], firstByte: UInt8, secondByte: UInt8?) -> UInt8 { - var index = (firstByte & 0b00000011) << 4 - if let secondByte = secondByte { - index += (secondByte & 0b11110000) >> 4 + + // MARK: Internal + + // The base64 unicode table. + @usableFromInline + static let encodeBase64: [UInt8] = [ + UInt8(ascii: "A"), UInt8(ascii: "B"), UInt8(ascii: "C"), UInt8(ascii: "D"), + UInt8(ascii: "E"), UInt8(ascii: "F"), UInt8(ascii: "G"), UInt8(ascii: "H"), + UInt8(ascii: "I"), UInt8(ascii: "J"), UInt8(ascii: "K"), UInt8(ascii: "L"), + UInt8(ascii: "M"), UInt8(ascii: "N"), UInt8(ascii: "O"), UInt8(ascii: "P"), + UInt8(ascii: "Q"), UInt8(ascii: "R"), UInt8(ascii: "S"), UInt8(ascii: "T"), + UInt8(ascii: "U"), UInt8(ascii: "V"), UInt8(ascii: "W"), UInt8(ascii: "X"), + UInt8(ascii: "Y"), UInt8(ascii: "Z"), UInt8(ascii: "a"), UInt8(ascii: "b"), + UInt8(ascii: "c"), UInt8(ascii: "d"), UInt8(ascii: "e"), UInt8(ascii: "f"), + UInt8(ascii: "g"), UInt8(ascii: "h"), UInt8(ascii: "i"), UInt8(ascii: "j"), + UInt8(ascii: "k"), UInt8(ascii: "l"), UInt8(ascii: "m"), UInt8(ascii: "n"), + UInt8(ascii: "o"), UInt8(ascii: "p"), UInt8(ascii: "q"), UInt8(ascii: "r"), + UInt8(ascii: "s"), UInt8(ascii: "t"), UInt8(ascii: "u"), UInt8(ascii: "v"), + UInt8(ascii: "w"), UInt8(ascii: "x"), UInt8(ascii: "y"), UInt8(ascii: "z"), + UInt8(ascii: "0"), UInt8(ascii: "1"), UInt8(ascii: "2"), UInt8(ascii: "3"), + UInt8(ascii: "4"), UInt8(ascii: "5"), UInt8(ascii: "6"), UInt8(ascii: "7"), + UInt8(ascii: "8"), UInt8(ascii: "9"), UInt8(ascii: "+"), UInt8(ascii: "/"), + ] + + static let encodePaddingCharacter: UInt8 = UInt8(ascii: "=") + + @usableFromInline + static func encode(alphabet: [UInt8], firstByte: UInt8) -> UInt8 { + let index = firstByte >> 2 + return alphabet[Int(index)] } - return alphabet[Int(index)] - } - - @usableFromInline - static func encode(alphabet: [UInt8], secondByte: UInt8?, thirdByte: UInt8?) -> UInt8 { - guard let secondByte = secondByte else { - // No second byte means we are just emitting padding. - return Base64.encodePaddingCharacter + + @usableFromInline + static func encode(alphabet: [UInt8], firstByte: UInt8, secondByte: UInt8?) -> UInt8 { + var index = (firstByte & 0b00000011) << 4 + if let secondByte = secondByte { + index += (secondByte & 0b11110000) >> 4 + } + return alphabet[Int(index)] } - var index = (secondByte & 0b00001111) << 2 - if let thirdByte = thirdByte { - index += (thirdByte & 0b11000000) >> 6 + + @usableFromInline + static func encode(alphabet: [UInt8], secondByte: UInt8?, thirdByte: UInt8?) -> UInt8 { + guard let secondByte = secondByte else { + // No second byte means we are just emitting padding. + return Base64.encodePaddingCharacter + } + var index = (secondByte & 0b00001111) << 2 + if let thirdByte = thirdByte { + index += (thirdByte & 0b11000000) >> 6 + } + return alphabet[Int(index)] } - return alphabet[Int(index)] - } - - @usableFromInline - static func encode(alphabet: [UInt8], thirdByte: UInt8?) -> UInt8 { - guard let thirdByte = thirdByte else { - // No third byte means just padding. - return Base64.encodePaddingCharacter + + @usableFromInline + static func encode(alphabet: [UInt8], thirdByte: UInt8?) -> UInt8 { + guard let thirdByte = thirdByte else { + // No third byte means just padding. + return Base64.encodePaddingCharacter + } + let index = thirdByte & 0b00111111 + return alphabet[Int(index)] } - let index = thirdByte & 0b00111111 - return alphabet[Int(index)] - } } extension String { - /// This is a backport of a proposed String initializer that will allow writing directly into an uninitialized String's backing memory. - /// - /// As this API does not exist prior to 5.3 on Linux, or on older Apple platforms, we fake it out with a pointer and accept the extra copy. - @inlinable - init(backportUnsafeUninitializedCapacity capacity: Int, - initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int) rethrows { - // The buffer will store zero terminated C string - let buffer = UnsafeMutableBufferPointer.allocate(capacity: capacity + 1) - defer { - buffer.deallocate() + /// This is a backport of a proposed String initializer that will allow writing directly into an uninitialized String's backing memory. + /// + /// As this API does not exist prior to 5.3 on Linux, or on older Apple platforms, we fake it out with a pointer and accept the extra copy. + @inlinable + init( + backportUnsafeUninitializedCapacity capacity: Int, + initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int + ) rethrows { + // The buffer will store zero terminated C string + let buffer = UnsafeMutableBufferPointer.allocate(capacity: capacity + 1) + defer { + buffer.deallocate() + } + + let initializedCount = try initializer(buffer) + precondition(initializedCount <= capacity, "Overran buffer in initializer!") + // add zero termination + buffer[initializedCount] = 0 + + self = String(cString: buffer.baseAddress!) } - - let initializedCount = try initializer(buffer) - precondition(initializedCount <= capacity, "Overran buffer in initializer!") - // add zero termination - buffer[initializedCount] = 0 - - self = String(cString: buffer.baseAddress!) - } } -// Frustratingly, Swift 5.3 shipped before the macOS 11 SDK did, so we cannot gate the availability of -// this declaration on having the 5.3 compiler. This has caused a number of build issues. While updating -// to newer Xcodes does work, we can save ourselves some hassle and just wait until 5.4 to get this -// enhancement on Apple platforms. -#if (compiler(>=5.3) && !(os(macOS) || os(iOS) || os(tvOS) || os(watchOS))) || compiler(>=5.4) extension String { - @inlinable - init(customUnsafeUninitializedCapacity capacity: Int, - initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int) rethrows { - if #available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *) { - try self.init(unsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) - } else { - try self.init(backportUnsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) + @inlinable + init( + customUnsafeUninitializedCapacity capacity: Int, + initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int + ) rethrows { + if #available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *) { + try self.init(unsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) + } else { + try self.init(backportUnsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) + } } - } -} -#else -extension String { - @inlinable - init(customUnsafeUninitializedCapacity capacity: Int, - initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int) rethrows { - try self.init(backportUnsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) - } } -#endif diff --git a/Sources/AsyncHTTPClient/BasicAuth.swift b/Sources/AsyncHTTPClient/BasicAuth.swift new file mode 100644 index 000000000..3e69f8277 --- /dev/null +++ b/Sources/AsyncHTTPClient/BasicAuth.swift @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2024 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Foundation +import NIOHTTP1 + +/// Generates base64 encoded username + password for http basic auth. +/// +/// - Parameters: +/// - username: the username to authenticate with +/// - password: authentication password associated with the username +/// - Returns: encoded credentials to use the Authorization: Basic http header. +func encodeBasicAuthCredentials(username: String, password: String) -> String { + var value = Data() + value.reserveCapacity(username.utf8.count + password.utf8.count + 1) + value.append(contentsOf: username.utf8) + value.append(UInt8(ascii: ":")) + value.append(contentsOf: password.utf8) + return value.base64EncodedString() +} + +extension HTTPHeaders { + /// Sets the basic auth header + mutating func setBasicAuth(username: String, password: String) { + let encoded = encodeBasicAuthCredentials(username: username, password: password) + self.replaceOrAdd(name: "Authorization", value: "Basic \(encoded)") + } +} diff --git a/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift b/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift index 58169f645..aca0ce235 100644 --- a/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift +++ b/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift @@ -27,6 +27,6 @@ struct BestEffortHashableTLSConfiguration: Hashable { } static func == (lhs: BestEffortHashableTLSConfiguration, rhs: BestEffortHashableTLSConfiguration) -> Bool { - return lhs.base.bestEffortEquals(rhs.base) + lhs.base.bestEffortEquals(rhs.base) } } diff --git a/Sources/AsyncHTTPClient/Configuration+BrowserLike.swift b/Sources/AsyncHTTPClient/Configuration+BrowserLike.swift new file mode 100644 index 000000000..5a0abdfad --- /dev/null +++ b/Sources/AsyncHTTPClient/Configuration+BrowserLike.swift @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import NIOCore +import NIOHTTPCompression +import NIOSSL + +// swift-format-ignore: DontRepeatTypeInStaticProperties +extension HTTPClient.Configuration { + /// The ``HTTPClient/Configuration`` for ``HTTPClient/shared`` which tries to mimic the platform's default or prevalent browser as closely as possible. + /// + /// Don't rely on specific values of this configuration as they're subject to change. You can rely on them being somewhat sensible though. + /// + /// - note: At present, this configuration is nowhere close to a real browser configuration but in case of disagreements we will choose values that match + /// the default browser as closely as possible. + /// + /// Platform's default/prevalent browsers that we're trying to match (these might change over time): + /// - macOS: Safari + /// - iOS: Safari + /// - Android: Google Chrome + /// - Linux (non-Android): Google Chrome + public static var singletonConfiguration: HTTPClient.Configuration { + // To start with, let's go with these values. Obtained from Firefox's config. + HTTPClient.Configuration( + certificateVerification: .fullVerification, + redirectConfiguration: .follow(max: 20, allowCycles: false), + timeout: Timeout(connect: .seconds(90), read: .seconds(90)), + connectionPool: .seconds(600), + proxy: nil, + ignoreUncleanSSLShutdown: false, + decompression: .enabled(limit: .ratio(25)), + backgroundActivityLogger: nil + ) + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift index 0dac50e5f..3b45eca05 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -12,8 +12,32 @@ // //===----------------------------------------------------------------------===// +import CNIOLinux +import NIOCore import NIOSSL +#if canImport(Darwin) +import Darwin.C +#elseif canImport(Musl) +import Musl +#elseif canImport(Android) +import Android +#elseif os(Linux) || os(FreeBSD) +import Glibc +#else +#error("unsupported target operating system") +#endif + +extension String { + var isIPAddress: Bool { + var ipv4Address = in_addr() + var ipv6Address = in6_addr() + return self.withCString { host in + inet_pton(AF_INET, host, &ipv4Address) == 1 || inet_pton(AF_INET6, host, &ipv6Address) == 1 + } + } +} + enum ConnectionPool { /// Used by the `ConnectionPool` to index its `HTTP1ConnectionProvider`s /// @@ -24,15 +48,18 @@ enum ConnectionPool { var scheme: Scheme var connectionTarget: ConnectionTarget private var tlsConfiguration: BestEffortHashableTLSConfiguration? + var serverNameIndicatorOverride: String? init( scheme: Scheme, connectionTarget: ConnectionTarget, - tlsConfiguration: BestEffortHashableTLSConfiguration? = nil + tlsConfiguration: BestEffortHashableTLSConfiguration? = nil, + serverNameIndicatorOverride: String? ) { self.scheme = scheme self.connectionTarget = connectionTarget self.tlsConfiguration = tlsConfiguration + self.serverNameIndicatorOverride = serverNameIndicatorOverride } var description: String { @@ -43,31 +70,50 @@ enum ConnectionPool { switch self.connectionTarget { case .ipAddress(let serialization, let addr): hostDescription = "\(serialization):\(addr.port!)" - case .domain(let domain, port: let port): + case .domain(let domain, let port): hostDescription = "\(domain):\(port)" case .unixSocket(let socketPath): hostDescription = socketPath } - return "\(self.scheme)://\(hostDescription) TLS-hash: \(hash)" + return + "\(self.scheme)://\(hostDescription)\(self.serverNameIndicatorOverride.map { " SNI: \($0)" } ?? "") TLS-hash: \(hash)" } } } +extension DeconstructedURL { + func applyDNSOverride(_ dnsOverride: [String: String]) -> (ConnectionTarget, serverNameIndicatorOverride: String?) { + guard + let originalHost = self.connectionTarget.host, + let hostOverride = dnsOverride[originalHost] + else { + return (self.connectionTarget, nil) + } + return ( + .init(remoteHost: hostOverride, port: self.connectionTarget.port ?? self.scheme.defaultPort), + serverNameIndicatorOverride: originalHost.isIPAddress ? nil : originalHost + ) + } +} + extension ConnectionPool.Key { - init(url: DeconstructedURL, tlsConfiguration: TLSConfiguration?) { + init(url: DeconstructedURL, tlsConfiguration: TLSConfiguration?, dnsOverride: [String: String]) { + let (connectionTarget, serverNameIndicatorOverride) = url.applyDNSOverride(dnsOverride) self.init( scheme: url.scheme, - connectionTarget: url.connectionTarget, + connectionTarget: connectionTarget, tlsConfiguration: tlsConfiguration.map { BestEffortHashableTLSConfiguration(wrapping: $0) - } + }, + serverNameIndicatorOverride: serverNameIndicatorOverride ) } - init(_ request: HTTPClient.Request) { + init(_ request: HTTPClient.Request, dnsOverride: [String: String] = [:]) { self.init( url: request.deconstructedURL, - tlsConfiguration: request.tlsConfiguration + tlsConfiguration: request.tlsConfiguration, + dnsOverride: dnsOverride ) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift index 7340a59ea..1636fe379 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift @@ -42,7 +42,7 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand private var proxyEstablishedPromise: EventLoopPromise? var proxyEstablishedFuture: EventLoopFuture? { - return self.proxyEstablishedPromise?.futureResult + self.proxyEstablishedPromise?.futureResult } convenience init( @@ -53,10 +53,10 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand let targetHost: String let targetPort: Int switch target { - case .ipAddress(serialization: let serialization, address: let address): + case .ipAddress(let serialization, let address): targetHost = serialization targetPort = address.port! - case .domain(name: let domain, port: let port): + case .domain(name: let domain, let port): targetHost = domain targetPort = port case .unixSocket: @@ -70,10 +70,12 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand ) } - init(targetHost: String, - targetPort: Int, - proxyAuthorization: HTTPClient.Authorization?, - deadline: NIODeadline) { + init( + targetHost: String, + targetPort: Int, + proxyAuthorization: HTTPClient.Authorization?, + deadline: NIODeadline + ) { self.targetHost = targetHost self.targetPort = targetPort self.proxyAuthorization = proxyAuthorization @@ -135,7 +137,7 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand return } - let timeout = context.eventLoop.scheduleTask(deadline: self.deadline) { + let timeout = context.eventLoop.assumeIsolated().scheduleTask(deadline: self.deadline) { switch self.state { case .initialized: preconditionFailure("How can we have a scheduled timeout, if the connection is not even up?") @@ -155,6 +157,7 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand method: .CONNECT, uri: "\(self.targetHost):\(self.targetPort)" ) + head.headers.replaceOrAdd(name: "host", value: "\(self.targetHost)") if let authorization = self.proxyAuthorization { head.headers.replaceOrAdd(name: "proxy-authorization", value: authorization.headerValue) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift index 5a46f44a7..7458627fd 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift @@ -31,7 +31,7 @@ final class SOCKSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { private var socksEstablishedPromise: EventLoopPromise? var socksEstablishedFuture: EventLoopFuture? { - return self.socksEstablishedPromise?.futureResult + self.socksEstablishedPromise?.futureResult } private let deadline: NIODeadline @@ -99,7 +99,7 @@ final class SOCKSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { return } - let scheduled = context.eventLoop.scheduleTask(deadline: self.deadline) { + let scheduled = context.eventLoop.assumeIsolated().scheduleTask(deadline: self.deadline) { switch self.state { case .initialized, .channelActive: // close the connection, if the handshake timed out diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift index aab26fda8..d210b2747 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift @@ -31,7 +31,7 @@ final class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { private var tlsEstablishedPromise: EventLoopPromise? var tlsEstablishedFuture: EventLoopFuture? { - return self.tlsEstablishedPromise?.futureResult + self.tlsEstablishedPromise?.futureResult } private let deadline: NIODeadline? @@ -104,7 +104,7 @@ final class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { var scheduled: Scheduled? if let deadline = deadline { - scheduled = context.eventLoop.scheduleTask(deadline: deadline) { + scheduled = context.eventLoop.assumeIsolated().scheduleTask(deadline: deadline) { switch self.state { case .initialized, .channelActive: // close the connection, if the handshake timed out diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift index 9d1a3b5fd..191517c71 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift @@ -35,16 +35,24 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { didSet { if let newRequest = self.request { var requestLogger = newRequest.logger - requestLogger[metadataKey: "ahc-connection-id"] = "\(self.connection.id)" - requestLogger[metadataKey: "ahc-el"] = "\(self.connection.channel.eventLoop)" + requestLogger[metadataKey: "ahc-connection-id"] = self.connectionIdLoggerMetadata + requestLogger[metadataKey: "ahc-el"] = self.eventLoopDescription self.logger = requestLogger if let idleReadTimeout = newRequest.requestOptions.idleReadTimeout { self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) } + + if let idleWriteTimeout = newRequest.requestOptions.idleWriteTimeout { + self.idleWriteTimeoutStateMachine = .init( + timeAmount: idleWriteTimeout, + isWritabilityEnabled: self.channelContext?.channel.isWritable ?? false + ) + } } else { self.logger = self.backgroundLogger self.idleReadTimeoutStateMachine = nil + self.idleWriteTimeoutStateMachine = nil } } } @@ -52,22 +60,28 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { private var idleReadTimeoutStateMachine: IdleReadStateMachine? private var idleReadTimeoutTimer: Scheduled? + private var idleWriteTimeoutStateMachine: IdleWriteStateMachine? + private var idleWriteTimeoutTimer: Scheduled? + /// Cancelling a task in NIO does *not* guarantee that the task will not execute under certain race conditions. /// We therefore give each timer an ID and increase the ID every time we reset or cancel it. /// We check in the task if the timer ID has changed in the meantime and do not execute any action if has changed. private var currentIdleReadTimeoutTimerID: Int = 0 + private var currentIdleWriteTimeoutTimerID: Int = 0 private let backgroundLogger: Logger private var logger: Logger + private let eventLoop: EventLoop + private let eventLoopDescription: Logger.MetadataValue + private let connectionIdLoggerMetadata: Logger.MetadataValue - let connection: HTTP1Connection - let eventLoop: EventLoop - - init(connection: HTTP1Connection, eventLoop: EventLoop, logger: Logger) { - self.connection = connection + var onConnectionIdle: () -> Void = {} + init(eventLoop: EventLoop, backgroundLogger: Logger, connectionIdLoggerMetadata: Logger.MetadataValue) { self.eventLoop = eventLoop - self.backgroundLogger = logger - self.logger = self.backgroundLogger + self.eventLoopDescription = "\(eventLoop.description)" + self.backgroundLogger = backgroundLogger + self.logger = backgroundLogger + self.connectionIdLoggerMetadata = connectionIdLoggerMetadata } func handlerAdded(context: ChannelHandlerContext) { @@ -86,9 +100,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { // MARK: Channel Inbound Handler func channelActive(context: ChannelHandlerContext) { - self.logger.trace("Channel active", metadata: [ - "ahc-channel-writable": "\(context.channel.isWritable)", - ]) + self.logger.trace( + "Channel active", + metadata: [ + "ahc-channel-writable": "\(context.channel.isWritable)" + ] + ) let action = self.state.channelActive(isWritable: context.channel.isWritable) self.run(action, context: context) @@ -102,20 +119,31 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } func channelWritabilityChanged(context: ChannelHandlerContext) { - self.logger.trace("Channel writability changed", metadata: [ - "ahc-channel-writable": "\(context.channel.isWritable)", - ]) + self.logger.trace( + "Channel writability changed", + metadata: [ + "ahc-channel-writable": "\(context.channel.isWritable)" + ] + ) + + if let timeoutAction = self.idleWriteTimeoutStateMachine?.channelWritabilityChanged(context: context) { + self.runTimeoutAction(timeoutAction, context: context) + } let action = self.state.writabilityChanged(writable: context.channel.isWritable) self.run(action, context: context) + context.fireChannelWritabilityChanged() } func channelRead(context: ChannelHandlerContext, data: NIOAny) { let httpPart = self.unwrapInboundIn(data) - self.logger.trace("HTTP response part received", metadata: [ - "ahc-http-part": "\(httpPart)", - ]) + self.logger.trace( + "HTTP response part received", + metadata: [ + "ahc-http-part": "\(httpPart)" + ] + ) if let timeoutAction = self.idleReadTimeoutStateMachine?.channelRead(httpPart) { self.runTimeoutAction(timeoutAction, context: context) @@ -133,9 +161,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } func errorCaught(context: ChannelHandlerContext, error: Error) { - self.logger.trace("Channel error caught", metadata: [ - "ahc-error": "\(error)", - ]) + self.logger.trace( + "Channel error caught", + metadata: [ + "ahc-error": "\(error)" + ] + ) let action = self.state.errorHappened(error) self.run(action, context: context) @@ -149,7 +180,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.request = req self.logger.debug("Request was scheduled on connection") - req.willExecuteRequest(self) + + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + + req.willExecuteRequest(self.requestExecutor) let action = self.state.runNewRequest( head: req.requestHead, @@ -182,17 +218,39 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { private func run(_ action: HTTP1ConnectionStateMachine.Action, context: ChannelHandlerContext) { switch action { - case .sendRequestHead(let head, startBody: let startBody): - self.sendRequestHead(head, startBody: startBody, context: context) + case .sendRequestHead(let head, let sendEnd): + self.sendRequestHead(head, sendEnd: sendEnd, context: context) + case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet + self.request!.requestHeadSent() + if resumeRequestBodyStream, let request = self.request { + // The above request head send notification might lead the request to mark itself as + // cancelled, which in turn might pop the request of the handler. For this reason we + // must check if the request is still present here. + request.resumeRequestBodyStream() + } + if startIdleTimer { + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } - case .sendBodyPart(let part): - context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: nil) + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) + } + } + case .sendBodyPart(let part, let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: writePromise) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .sendRequestEnd(let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } + + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) } case .pauseRequestBodyStream: @@ -256,66 +314,80 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { let oldRequest = self.request! self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) switch finalAction { case .close: context.close(promise: nil) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + oldRequest.succeedRequest(buffer) + case .sendRequestEnd(let writePromise, let shouldClose): + let writePromise = writePromise ?? context.eventLoop.makePromise(of: Void.self) + // We need to defer succeeding the old request to avoid ordering issues + writePromise.futureResult.hop(to: context.eventLoop).assumeIsolated().whenComplete { result in + switch result { + case .success: + // If our final action was `sendRequestEnd`, that means we've already received + // the complete response. As a result, once we've uploaded all the body parts + // we need to tell the pool that the connection is idle or, if we were asked to + // close when we're done, send the close. Either way, we then succeed the request + if shouldClose { + context.close(promise: nil) + } else { + self.onConnectionIdle() + } + + oldRequest.succeedRequest(buffer) + case .failure(let error): + context.close(promise: nil) + oldRequest.fail(error) + } + } + + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) case .informConnectionIsIdle: - self.connection.taskCompleted() - case .none: - break + self.onConnectionIdle() + oldRequest.succeedRequest(buffer) } - oldRequest.succeedRequest(buffer) - case .failRequest(let error, let finalAction): // see comment in the `succeedRequest` case. let oldRequest = self.request! self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) switch finalAction { - case .close: + case .close(let writePromise): context.close(promise: nil) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + writePromise?.fail(error) + oldRequest.fail(error) + case .informConnectionIsIdle: - self.connection.taskCompleted() + self.onConnectionIdle() + oldRequest.fail(error) + + case .failWritePromise(let writePromise): + writePromise?.fail(error) + oldRequest.fail(error) + case .none: - break + oldRequest.fail(error) } - oldRequest.fail(error) + case .failSendBodyPart(let error, let writePromise), .failSendStreamFinished(let error, let writePromise): + writePromise?.fail(error) } } - private func sendRequestHead(_ head: HTTPRequestHead, startBody: Bool, context: ChannelHandlerContext) { - if startBody { - context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) - - // The above write might trigger an error, which may lead to a call to `errorCaught`, - // which in turn, may fail the request and pop it from the handler. For this reason - // we must check if the request is still present here. - guard let request = self.request else { return } - request.requestHeadSent() - request.resumeRequestBodyStream() - } else { + private func sendRequestHead(_ head: HTTPRequestHead, sendEnd: Bool, context: ChannelHandlerContext) { + if sendEnd { context.write(self.wrapOutboundOut(.head(head)), promise: nil) context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() - - // The above write might trigger an error, which may lead to a call to `errorCaught`, - // which in turn, may fail the request and pop it from the handler. For this reason - // we must check if the request is still present here. - guard let request = self.request else { return } - request.requestHeadSent() - - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) - } + } else { + context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) } + self.run(self.state.headSent(), context: context) } private func runTimeoutAction(_ action: IdleReadStateMachine.Action, context: ChannelHandlerContext) { @@ -324,7 +396,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { assert(self.idleReadTimeoutTimer == nil, "Expected there is no timeout timer so far.") let timerID = self.currentIdleReadTimeoutTimerID - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) @@ -337,7 +409,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.currentIdleReadTimeoutTimerID &+= 1 let timerID = self.currentIdleReadTimeoutTimerID - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) @@ -353,33 +425,77 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } } + private func runTimeoutAction(_ action: IdleWriteStateMachine.Action, context: ChannelHandlerContext) { + switch action { + case .startIdleWriteTimeoutTimer(let timeAmount): + assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.") + + let timerID = self.currentIdleWriteTimeoutTimerID + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { + guard self.currentIdleWriteTimeoutTimerID == timerID else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .resetIdleWriteTimeoutTimer(let timeAmount): + if let oldTimer = self.idleWriteTimeoutTimer { + oldTimer.cancel() + } + + self.currentIdleWriteTimeoutTimerID &+= 1 + let timerID = self.currentIdleWriteTimeoutTimerID + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { + guard self.currentIdleWriteTimeoutTimerID == timerID else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .clearIdleWriteTimeoutTimer: + if let oldTimer = self.idleWriteTimeoutTimer { + self.idleWriteTimeoutTimer = nil + self.currentIdleWriteTimeoutTimerID &+= 1 + oldTimer.cancel() + } + case .none: + break + } + } + // MARK: Private HTTPRequestExecutor - private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest) { + fileprivate func writeRequestBodyPart0( + _ data: IOData, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) { guard self.request === request, let context = self.channelContext else { // Because the HTTPExecutableRequest may run in a different thread to our eventLoop, // calls from the HTTPExecutableRequest to our ChannelHandler may arrive here after // the request has been popped by the state machine or the ChannelHandler has been // removed from the Channel pipeline. This is a normal threading issue, noone has // screwed up. + promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamPartReceived(data) + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + + let action = self.state.requestStreamPartReceived(data, promise: promise) self.run(action, context: context) } - private func finishRequestBodyStream0(_ request: HTTPExecutableRequest) { + fileprivate func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` + promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamFinished() + let action = self.state.requestStreamFinished(promise: promise) self.run(action, context: context) } - private func demandResponseBodyStream0(_ request: HTTPExecutableRequest) { + fileprivate func demandResponseBodyStream0(_ request: HTTPExecutableRequest) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return @@ -391,7 +507,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } - private func cancelRequest0(_ request: HTTPExecutableRequest) { + fileprivate func cancelRequest0(_ request: HTTPExecutableRequest) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return @@ -399,48 +515,51 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.logger.trace("Request was cancelled") + if let timeoutAction = self.idleWriteTimeoutStateMachine?.cancelRequest() { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.requestCancelled(closeConnection: true) self.run(action, context: context) } } -extension HTTP1ClientChannelHandler: HTTPRequestExecutor { - func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.writeRequestBodyPart0(data, request: request) - } else { - self.eventLoop.execute { - self.writeRequestBodyPart0(data, request: request) +@available(*, unavailable) +extension HTTP1ClientChannelHandler: Sendable {} + +extension HTTP1ClientChannelHandler { + var requestExecutor: RequestExecutor { + RequestExecutor(self) + } + + struct RequestExecutor: HTTPRequestExecutor, Sendable { + private let loopBound: NIOLoopBound + + init(_ handler: HTTP1ClientChannelHandler) { + self.loopBound = NIOLoopBound(handler, eventLoop: handler.eventLoop) + } + + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.writeRequestBodyPart0(data, request: request, promise: promise) } } - } - func finishRequestBodyStream(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.finishRequestBodyStream0(request) - } else { - self.eventLoop.execute { - self.finishRequestBodyStream0(request) + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.finishRequestBodyStream0(request, promise: promise) } } - } - func demandResponseBodyStream(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.demandResponseBodyStream0(request) - } else { - self.eventLoop.execute { - self.demandResponseBodyStream0(request) + func demandResponseBodyStream(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.demandResponseBodyStream0(request) } } - } - func cancelRequest(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.cancelRequest0(request) - } else { - self.eventLoop.execute { - self.cancelRequest0(request) + func cancelRequest(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.cancelRequest0(request) } } } @@ -508,3 +627,90 @@ struct IdleReadStateMachine { } } } + +struct IdleWriteStateMachine { + enum Action { + case startIdleWriteTimeoutTimer(TimeAmount) + case resetIdleWriteTimeoutTimer(TimeAmount) + case clearIdleWriteTimeoutTimer + case none + } + + enum State { + case waitingForRequestEnd + case waitingForWritabilityEnabled + case requestEndSent + } + + private var state: State + private let timeAmount: TimeAmount + + init(timeAmount: TimeAmount, isWritabilityEnabled: Bool) { + self.timeAmount = timeAmount + if isWritabilityEnabled { + self.state = .waitingForRequestEnd + } else { + self.state = .waitingForWritabilityEnabled + } + } + + mutating func cancelRequest() -> Action { + switch self.state { + case .waitingForRequestEnd, .waitingForWritabilityEnabled: + self.state = .requestEndSent + return .clearIdleWriteTimeoutTimer + case .requestEndSent: + return .none + } + } + + mutating func write() -> Action { + switch self.state { + case .waitingForRequestEnd: + return .resetIdleWriteTimeoutTimer(self.timeAmount) + case .waitingForWritabilityEnabled: + return .none + case .requestEndSent: + preconditionFailure("If the request end has been sent, we can't write more data.") + } + } + + mutating func requestEndSent() -> Action { + switch self.state { + case .waitingForRequestEnd: + self.state = .requestEndSent + return .clearIdleWriteTimeoutTimer + case .waitingForWritabilityEnabled: + self.state = .requestEndSent + return .none + case .requestEndSent: + return .none + } + } + + mutating func channelWritabilityChanged(context: ChannelHandlerContext) -> Action { + if context.channel.isWritable { + switch self.state { + case .waitingForRequestEnd: + preconditionFailure("If waiting for more data, the channel was already writable.") + case .waitingForWritabilityEnabled: + self.state = .waitingForRequestEnd + return .startIdleWriteTimeoutTimer(self.timeAmount) + case .requestEndSent: + return .none + } + } else { + switch self.state { + case .waitingForRequestEnd: + self.state = .waitingForWritabilityEnabled + return .clearIdleWriteTimeoutTimer + case .waitingForWritabilityEnabled: + preconditionFailure( + "If the channel was writable before, then we should have been waiting for more data." + ) + case .requestEndSent: + return .none + } + } + } +} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift index 173ac79e4..6f64e0407 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift @@ -17,9 +17,9 @@ import NIOCore import NIOHTTP1 import NIOHTTPCompression -protocol HTTP1ConnectionDelegate { - func http1ConnectionReleased(_: HTTP1Connection) - func http1ConnectionClosed(_: HTTP1Connection) +protocol HTTP1ConnectionDelegate: Sendable { + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) } final class HTTP1Connection { @@ -39,9 +39,11 @@ final class HTTP1Connection { let id: HTTPConnectionPool.Connection.ID - init(channel: Channel, - connectionID: HTTPConnectionPool.Connection.ID, - delegate: HTTP1ConnectionDelegate) { + init( + channel: Channel, + connectionID: HTTPConnectionPool.Connection.ID, + delegate: HTTP1ConnectionDelegate + ) { self.channel = channel self.id = connectionID self.delegate = delegate @@ -57,40 +59,53 @@ final class HTTP1Connection { channel: Channel, connectionID: HTTPConnectionPool.Connection.ID, delegate: HTTP1ConnectionDelegate, - configuration: HTTPClient.Configuration, + decompression: HTTPClient.Decompression, logger: Logger ) throws -> HTTP1Connection { let connection = HTTP1Connection(channel: channel, connectionID: connectionID, delegate: delegate) - try connection.start(configuration: configuration, logger: logger) + try connection.start(decompression: decompression, logger: logger) return connection } - func executeRequest(_ request: HTTPExecutableRequest) { - if self.channel.eventLoop.inEventLoop { - self.execute0(request: request) - } else { - self.channel.eventLoop.execute { - self.execute0(request: request) + var sendableView: SendableView { + SendableView(self) + } + + struct SendableView: Sendable { + private let connection: NIOLoopBound + let channel: Channel + let id: HTTPConnectionPool.Connection.ID + private var eventLoop: EventLoop { self.connection.eventLoop } + + init(_ connection: HTTP1Connection) { + self.connection = NIOLoopBound(connection, eventLoop: connection.channel.eventLoop) + self.id = connection.id + self.channel = connection.channel + } + + func executeRequest(_ request: HTTPExecutableRequest) { + self.connection.execute { + $0.execute0(request: request) } } - } - func shutdown() { - self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) - } + func shutdown() { + self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) + } - func close(promise: EventLoopPromise?) { - return self.channel.close(mode: .all, promise: promise) - } + func close(promise: EventLoopPromise?) { + self.channel.close(mode: .all, promise: promise) + } - func close() -> EventLoopFuture { - let promise = self.channel.eventLoop.makePromise(of: Void.self) - self.close(promise: promise) - return promise.futureResult + func close() -> EventLoopFuture { + let promise = self.eventLoop.makePromise(of: Void.self) + self.close(promise: promise) + return promise.futureResult + } } func taskCompleted() { - self.delegate.http1ConnectionReleased(self) + self.delegate.http1ConnectionReleased(self.id) } private func execute0(request: HTTPExecutableRequest) { @@ -98,10 +113,10 @@ final class HTTP1Connection { return request.fail(ChannelError.ioOnClosedChannel) } - self.channel.write(request, promise: nil) + self.channel.pipeline.syncOperations.write(NIOAny(request), promise: nil) } - private func start(configuration: HTTPClient.Configuration, logger: Logger) throws { + private func start(decompression: HTTPClient.Decompression, logger: Logger) throws { self.channel.eventLoop.assertInEventLoop() guard case .initialized = self.state else { @@ -109,9 +124,9 @@ final class HTTP1Connection { } self.state = .active - self.channel.closeFuture.whenComplete { _ in + self.channel.closeFuture.assumeIsolated().whenComplete { _ in self.state = .closed - self.delegate.http1ConnectionClosed(self) + self.delegate.http1ConnectionClosed(self.id) } do { @@ -127,16 +142,19 @@ final class HTTP1Connection { try sync.addHandler(requestEncoder) try sync.addHandler(ByteToMessageHandler(responseDecoder)) - if case .enabled(let limit) = configuration.decompression { + if case .enabled(let limit) = decompression { let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) try sync.addHandler(decompressHandler) } let channelHandler = HTTP1ClientChannelHandler( - connection: self, eventLoop: channel.eventLoop, - logger: logger + backgroundLogger: logger, + connectionIdLoggerMetadata: "\(self.id)" ) + channelHandler.onConnectionIdle = { + self.taskCompleted() + } try sync.addHandler(channelHandler) } catch { @@ -145,3 +163,6 @@ final class HTTP1Connection { } } } + +@available(*, unavailable) +extension HTTP1Connection: Sendable {} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift index 19825aec7..2cde1df3f 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift @@ -28,21 +28,44 @@ struct HTTP1ConnectionStateMachine { enum Action { /// A action to execute, when we consider a request "done". - enum FinalStreamAction { + enum FinalSuccessfulStreamAction { /// Close the connection case close /// If the server has replied, with a status of 200...300 before all data was sent, a request is considered succeeded, /// as soon as we wrote the request end onto the wire. - case sendRequestEnd + /// + /// The promise is an optional write promise. + /// + /// `shouldClose` records whether we have attached a Connection: close header to this request, and so the connection should + /// be terminated + case sendRequestEnd(EventLoopPromise?, shouldClose: Bool) /// Inform an observer that the connection has become idle case informConnectionIsIdle + } + + /// A action to execute, when we consider a request "done". + enum FinalFailedStreamAction { + /// Close the connection + /// + /// The promise is an optional write promise. + case close(EventLoopPromise?) + /// Inform an observer that the connection has become idle + case informConnectionIsIdle + /// Fail the write promise + case failWritePromise(EventLoopPromise?) /// Do nothing. case none } - case sendRequestHead(HTTPRequestHead, startBody: Bool) - case sendBodyPart(IOData) - case sendRequestEnd + case sendRequestHead(HTTPRequestHead, sendEnd: Bool) + case notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: Bool, + startIdleTimer: Bool + ) + case sendBodyPart(IOData, EventLoopPromise?) + case sendRequestEnd(EventLoopPromise?) + case failSendBodyPart(Error, EventLoopPromise?) + case failSendStreamFinished(Error, EventLoopPromise?) case pauseRequestBodyStream case resumeRequestBodyStream @@ -50,8 +73,8 @@ struct HTTP1ConnectionStateMachine { case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyParts(CircularBuffer) - case failRequest(Error, FinalStreamAction) - case succeedRequest(FinalStreamAction, CircularBuffer) + case failRequest(Error, FinalFailedStreamAction) + case succeedRequest(FinalSuccessfulStreamAction, CircularBuffer) case read case close @@ -83,14 +106,14 @@ struct HTTP1ConnectionStateMachine { return .wait case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } mutating func channelInactive() -> Action { switch self.state { case .initialized: - preconditionFailure("A channel that isn't active, must not become inactive") + fatalError("A channel that isn't active, must not become inactive") case .inRequest(var requestStateMachine, close: _): return self.avoidingStateMachineCoW { state -> Action in @@ -107,7 +130,7 @@ struct HTTP1ConnectionStateMachine { return .wait case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } @@ -117,7 +140,7 @@ struct HTTP1ConnectionStateMachine { self.state = .closed return .fireChannelError(error, closeConnection: false) - case .inRequest(var requestStateMachine, close: let close): + case .inRequest(var requestStateMachine, let close): return self.avoidingStateMachineCoW { state -> Action in let action = requestStateMachine.errorHappened(error) state = .inRequest(requestStateMachine, close: close) @@ -132,7 +155,7 @@ struct HTTP1ConnectionStateMachine { return .fireChannelError(error, closeConnection: false) case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } @@ -150,7 +173,7 @@ struct HTTP1ConnectionStateMachine { } case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } @@ -159,15 +182,15 @@ struct HTTP1ConnectionStateMachine { metadata: RequestFramingMetadata ) -> Action { switch self.state { - case .initialized, .closing, .inRequest: + case .initialized, .inRequest: // These states are unreachable as the connection pool state machine has put the // connection into these states. In other words the connection pool state machine must // be aware about these states before the connection itself. For this reason the // connection pool state machine must not send a new request to the connection, if the // connection is `.initialized`, `.closing` or `.inRequest` - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") - case .closed: + case .closing, .closed: // The remote may have closed the connection and the connection pool state machine // was not updated yet because of a race condition. New request vs. marking connection // as closed. @@ -185,29 +208,29 @@ struct HTTP1ConnectionStateMachine { return self.state.modify(with: action) case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } - mutating func requestStreamPartReceived(_ part: IOData) -> Action { + mutating func requestStreamPartReceived(_ part: IOData, promise: EventLoopPromise?) -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in - let action = requestStateMachine.requestStreamPartReceived(part) + let action = requestStateMachine.requestStreamPartReceived(part, promise: promise) state = .inRequest(requestStateMachine, close: close) return state.modify(with: action) } } - mutating func requestStreamFinished() -> Action { + mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in - let action = requestStateMachine.requestStreamFinished() + let action = requestStateMachine.requestStreamFinished(promise: promise) state = .inRequest(requestStateMachine, close: close) return state.modify(with: action) } @@ -216,7 +239,9 @@ struct HTTP1ConnectionStateMachine { mutating func requestCancelled(closeConnection: Bool) -> Action { switch self.state { case .initialized: - preconditionFailure("This event must only happen, if the connection is leased. During startup this is impossible. Invalid state: \(self.state)") + fatalError( + "This event must only happen, if the connection is leased. During startup this is impossible. Invalid state: \(self.state)" + ) case .idle: if closeConnection { @@ -226,7 +251,7 @@ struct HTTP1ConnectionStateMachine { return .wait } - case .inRequest(var requestStateMachine, close: let close): + case .inRequest(var requestStateMachine, let close): return self.avoidingStateMachineCoW { state -> Action in let action = requestStateMachine.requestCancelled() state = .inRequest(requestStateMachine, close: close || closeConnection) @@ -237,7 +262,7 @@ struct HTTP1ConnectionStateMachine { return .wait case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } @@ -246,7 +271,7 @@ struct HTTP1ConnectionStateMachine { mutating func read() -> Action { switch self.state { case .initialized: - preconditionFailure("Why should we read something, if we are not connected yet") + fatalError("Why should we read something, if we are not connected yet") case .idle: return .read case .inRequest(var requestStateMachine, let close): @@ -261,14 +286,14 @@ struct HTTP1ConnectionStateMachine { return .read case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } mutating func channelRead(_ part: HTTPClientResponsePart) -> Action { switch self.state { case .initialized, .idle: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") case .inRequest(var requestStateMachine, var close): return self.avoidingStateMachineCoW { state -> Action in @@ -287,7 +312,7 @@ struct HTTP1ConnectionStateMachine { return .wait case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } @@ -304,13 +329,13 @@ struct HTTP1ConnectionStateMachine { } case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } mutating func demandMoreResponseBodyParts() -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in @@ -322,7 +347,7 @@ struct HTTP1ConnectionStateMachine { mutating func idleReadTimeoutTriggered() -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in @@ -331,6 +356,29 @@ struct HTTP1ConnectionStateMachine { return state.modify(with: action) } } + + mutating func idleWriteTimeoutTriggered() -> Action { + guard case .inRequest(var requestStateMachine, let close) = self.state else { + return .wait + } + + return self.avoidingStateMachineCoW { state -> Action in + let action = requestStateMachine.idleWriteTimeoutTriggered() + state = .inRequest(requestStateMachine, close: close) + return state.modify(with: action) + } + } + + mutating func headSent() -> Action { + guard case .inRequest(var requestStateMachine, let close) = self.state else { + return .wait + } + return self.avoidingStateMachineCoW { state in + let action = requestStateMachine.headSent() + state = .inRequest(requestStateMachine, close: close) + return state.modify(with: action) + } + } } extension HTTP1ConnectionStateMachine { @@ -369,34 +417,41 @@ extension HTTP1ConnectionStateMachine { } extension HTTP1ConnectionStateMachine.State { - fileprivate mutating func modify(with action: HTTPRequestStateMachine.Action) -> HTTP1ConnectionStateMachine.Action { + fileprivate mutating func modify(with action: HTTPRequestStateMachine.Action) -> HTTP1ConnectionStateMachine.Action + { switch action { - case .sendRequestHead(let head, let startBody): - return .sendRequestHead(head, startBody: startBody) + case .sendRequestHead(let head, let sendEnd): + return .sendRequestHead(head, sendEnd: sendEnd) + case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer): + return .notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: resumeRequestBodyStream, + startIdleTimer: startIdleTimer + ) case .pauseRequestBodyStream: return .pauseRequestBodyStream case .resumeRequestBodyStream: return .resumeRequestBodyStream - case .sendBodyPart(let part): - return .sendBodyPart(part) - case .sendRequestEnd: - return .sendRequestEnd + case .sendBodyPart(let part, let writePromise): + return .sendBodyPart(part, writePromise) + case .sendRequestEnd(let writePromise): + return .sendRequestEnd(writePromise) case .forwardResponseHead(let head, let pauseRequestBodyStream): return .forwardResponseHead(head, pauseRequestBodyStream: pauseRequestBodyStream) case .forwardResponseBodyParts(let parts): return .forwardResponseBodyParts(parts) case .succeedRequest(let finalAction, let finalParts): guard case .inRequest(_, close: let close) = self else { - preconditionFailure("Invalid state: \(self)") + fatalError("Invalid state: \(self)") } - let newFinalAction: HTTP1ConnectionStateMachine.Action.FinalStreamAction + let newFinalAction: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction switch finalAction { case .close: self = .closing newFinalAction = .close - case .sendRequestEnd: - newFinalAction = .sendRequestEnd + case .sendRequestEnd(let writePromise): + self = .idle + newFinalAction = .sendRequestEnd(writePromise, shouldClose: close) case .none: self = .idle newFinalAction = close ? .close : .informConnectionIsIdle @@ -406,13 +461,16 @@ extension HTTP1ConnectionStateMachine.State { case .failRequest(let error, let finalAction): switch self { case .initialized: - preconditionFailure("Invalid state: \(self)") + fatalError("Invalid state: \(self)") case .idle: - preconditionFailure("How can we fail a task, if we are idle") - case .inRequest(_, close: let close): - if close || finalAction == .close { + fatalError("How can we fail a task, if we are idle") + case .inRequest(_, let close): + if case .close(let promise) = finalAction { + self = .closing + return .failRequest(error, .close(promise)) + } else if close { self = .closing - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) } else { self = .idle return .failRequest(error, .informConnectionIsIdle) @@ -425,7 +483,7 @@ extension HTTP1ConnectionStateMachine.State { return .failRequest(error, .none) case .modifying: - preconditionFailure("Invalid state: \(self)") + fatalError("Invalid state: \(self)") } case .read: @@ -433,6 +491,12 @@ extension HTTP1ConnectionStateMachine.State { case .wait: return .wait + + case .failSendBodyPart(let error, let writePromise): + return .failSendBodyPart(error, writePromise) + + case .failSendStreamFinished(let error, let writePromise): + return .failSendStreamFinished(error, writePromise) } } } @@ -444,14 +508,14 @@ extension HTTP1ConnectionStateMachine: CustomStringConvertible { return ".initialized" case .idle: return ".idle" - case .inRequest(let request, close: let close): + case .inRequest(let request, let close): return ".inRequest(\(request), closeAfterRequest: \(close))" case .closing: return ".closing" case .closed: return ".closed" case .modifying: - preconditionFailure("Invalid state: \(self.state)") + fatalError("Invalid state: \(self.state)") } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index 8b2a50738..7c0197cdf 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -35,8 +35,16 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { private var request: HTTPExecutableRequest? { didSet { - if let newRequest = self.request, let idleReadTimeout = newRequest.requestOptions.idleReadTimeout { - self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) + if let newRequest = self.request { + if let idleReadTimeout = newRequest.requestOptions.idleReadTimeout { + self.idleReadTimeoutStateMachine = .init(timeAmount: idleReadTimeout) + } + if let idleWriteTimeout = newRequest.requestOptions.idleWriteTimeout { + self.idleWriteTimeoutStateMachine = .init( + timeAmount: idleWriteTimeout, + isWritabilityEnabled: self.channelContext?.channel.isWritable ?? false + ) + } } else { self.idleReadTimeoutStateMachine = nil } @@ -46,13 +54,24 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { private var idleReadTimeoutStateMachine: IdleReadStateMachine? private var idleReadTimeoutTimer: Scheduled? + private var idleWriteTimeoutStateMachine: IdleWriteStateMachine? + private var idleWriteTimeoutTimer: Scheduled? + + /// Cancelling a task in NIO does *not* guarantee that the task will not execute under certain race conditions. + /// We therefore give each timer an ID and increase the ID every time we reset or cancel it. + /// We check in the task if the timer ID has changed in the meantime and do not execute any action if has changed. + private var currentIdleReadTimeoutTimerID: Int = 0 + private var currentIdleWriteTimeoutTimerID: Int = 0 + init(eventLoop: EventLoop) { self.eventLoop = eventLoop } func handlerAdded(context: ChannelHandlerContext) { - assert(context.eventLoop === self.eventLoop, - "The handler must be added to a channel that runs on the eventLoop it was initialized with.") + assert( + context.eventLoop === self.eventLoop, + "The handler must be added to a channel that runs on the eventLoop it was initialized with." + ) self.channelContext = context let isWritable = context.channel.isActive && context.channel.isWritable @@ -77,6 +96,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } func channelWritabilityChanged(context: ChannelHandlerContext) { + if let timeoutAction = self.idleWriteTimeoutStateMachine?.channelWritabilityChanged(context: context) { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.writabilityChanged(writable: context.channel.isWritable) self.run(action, context: context) } @@ -110,7 +133,11 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // a single request. self.request = request - request.willExecuteRequest(self) + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + + request.willExecuteRequest(self.requestExecutor) let action = self.state.startRequest( head: request.requestHead, @@ -140,22 +167,44 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { private func run(_ action: HTTPRequestStateMachine.Action, context: ChannelHandlerContext) { switch action { - case .sendRequestHead(let head, let startBody): - self.sendRequestHead(head, startBody: startBody, context: context) - + case .sendRequestHead(let head, let sendEnd): + self.sendRequestHead(head, sendEnd: sendEnd, context: context) + case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer): + // We can force unwrap the request here, as we have just validated in the state machine, + // that the request is neither failed nor finished yet + self.request!.requestHeadSent() + if resumeRequestBodyStream, let request = self.request { + // The above request head send notification might lead the request to mark itself as + // cancelled, which in turn might pop the request of the handler. For this reason we + // must check if the request is still present here. + request.resumeRequestBodyStream() + } + if startIdleTimer { + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } + + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) + } + } case .pauseRequestBodyStream: // We can force unwrap the request here, as we have just validated in the state machine, // that the request is neither failed nor finished yet self.request!.pauseRequestBodyStream() - case .sendBodyPart(let data): - context.writeAndFlush(self.wrapOutboundOut(.body(data)), promise: nil) + case .sendBodyPart(let data, let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.body(data)), promise: writePromise) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .sendRequestEnd(let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) + if let readTimeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(readTimeoutAction, context: context) + } + + if let writeTimeoutAction = self.idleWriteTimeoutStateMachine?.requestEndSent() { + self.runTimeoutAction(writeTimeoutAction, context: context) } case .read: @@ -169,7 +218,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // that the request is neither failed nor finished yet self.request!.resumeRequestBodyStream() - case .forwardResponseHead(let head, pauseRequestBodyStream: let pauseRequestBodyStream): + case .forwardResponseHead(let head, let pauseRequestBodyStream): // We can force unwrap the request here, as we have just validated in the state machine, // that the request is neither failed nor finished yet self.request!.receiveResponseHead(head) @@ -185,17 +234,18 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // that the request is neither failed nor finished yet self.request!.receiveResponseBodyParts(parts) - case .failRequest(let error, _): + case .failRequest(let error, let finalAction): // We can force unwrap the request here, as we have just validated in the state machine, // that the request object is still present. self.request!.fail(error) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) // No matter the error reason, we must always make sure the h2 stream is closed. Only // once the h2 stream is closed, it is released from the h2 multiplexer. The // HTTPRequestStateMachine may signal finalAction: .none in the error case (as this is // the right result for HTTP/1). In the h2 case we MUST always close. - self.runFinalAction(.close, context: context) + self.runFailedFinalAction(finalAction, context: context, error: error) case .succeedRequest(let finalAction, let finalParts): // We can force unwrap the request here, as we have just validated in the state machine, @@ -203,44 +253,54 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.request!.succeedRequest(finalParts) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) - self.runFinalAction(finalAction, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) + self.runSuccessfulFinalAction(finalAction, context: context) + + case .failSendBodyPart(let error, let writePromise), .failSendStreamFinished(let error, let writePromise): + writePromise?.fail(error) } } - private func sendRequestHead(_ head: HTTPRequestHead, startBody: Bool, context: ChannelHandlerContext) { - if startBody { - context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) - - // The above write might trigger an error, which may lead to a call to `errorCaught`, - // which in turn, may fail the request and pop it from the handler. For this reason - // we must check if the request is still present here. - guard let request = self.request else { return } - request.requestHeadSent() - request.resumeRequestBodyStream() - } else { + private func sendRequestHead(_ head: HTTPRequestHead, sendEnd: Bool, context: ChannelHandlerContext) { + if sendEnd { context.write(self.wrapOutboundOut(.head(head)), promise: nil) context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() + } else { + context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) + } + self.run(self.state.headSent(), context: context) + } - // The above write might trigger an error, which may lead to a call to `errorCaught`, - // which in turn, may fail the request and pop it from the handler. For this reason - // we must check if the request is still present here. - guard let request = self.request else { return } - request.requestHeadSent() + private func runSuccessfulFinalAction( + _ action: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, + context: ChannelHandlerContext + ) { + switch action { + case .close, .none: + // The actions returned here come from an `HTTPRequestStateMachine` that assumes http/1.1 + // semantics. For this reason we can ignore the close here, since an h2 stream is closed + // after every request anyway. + break - if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { - self.runTimeoutAction(timeoutAction, context: context) - } + case .sendRequestEnd(let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) } } - private func runFinalAction(_ action: HTTPRequestStateMachine.Action.FinalStreamAction, context: ChannelHandlerContext) { - switch action { - case .close: - context.close(promise: nil) + private func runFailedFinalAction( + _ action: HTTPRequestStateMachine.Action.FinalFailedRequestAction, + context: ChannelHandlerContext, + error: Error + ) { + // We must close the http2 stream after the request has finished. Since the request failed, + // we have no idea what the h2 streams state was. To be on the save side, we explicitly close + // the h2 stream. This will break a reference cycle in HTTP2Connection. + context.close(promise: nil) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + switch action { + case .close(let writePromise): + writePromise?.fail(error) case .none: break @@ -252,8 +312,9 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { case .startIdleReadTimeoutTimer(let timeAmount): assert(self.idleReadTimeoutTimer == nil, "Expected there is no timeout timer so far.") - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { - guard self.idleReadTimeoutTimer != nil else { return } + let timerID = self.currentIdleReadTimeoutTimerID + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { + guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) } @@ -263,17 +324,54 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { oldTimer.cancel() } - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { - guard self.idleReadTimeoutTimer != nil else { return } + self.currentIdleReadTimeoutTimerID &+= 1 + let timerID = self.currentIdleReadTimeoutTimerID + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { + guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) } case .clearIdleReadTimeoutTimer: if let oldTimer = self.idleReadTimeoutTimer { self.idleReadTimeoutTimer = nil + self.currentIdleReadTimeoutTimerID &+= 1 + oldTimer.cancel() + } + + case .none: + break + } + } + + private func runTimeoutAction(_ action: IdleWriteStateMachine.Action, context: ChannelHandlerContext) { + switch action { + case .startIdleWriteTimeoutTimer(let timeAmount): + assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.") + + let timerID = self.currentIdleWriteTimeoutTimerID + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { + guard self.currentIdleWriteTimeoutTimerID == timerID else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .resetIdleWriteTimeoutTimer(let timeAmount): + if let oldTimer = self.idleWriteTimeoutTimer { oldTimer.cancel() } + self.currentIdleWriteTimeoutTimerID &+= 1 + let timerID = self.currentIdleWriteTimeoutTimerID + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { + guard self.currentIdleWriteTimeoutTimerID == timerID else { return } + let action = self.state.idleWriteTimeoutTriggered() + self.run(action, context: context) + } + case .clearIdleWriteTimeoutTimer: + if let oldTimer = self.idleWriteTimeoutTimer { + self.idleWriteTimeoutTimer = nil + self.currentIdleWriteTimeoutTimerID &+= 1 + oldTimer.cancel() + } case .none: break } @@ -281,27 +379,33 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // MARK: Private HTTPRequestExecutor - private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest) { + private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) + { guard self.request === request, let context = self.channelContext else { // Because the HTTPExecutableRequest may run in a different thread to our eventLoop, // calls from the HTTPExecutableRequest to our ChannelHandler may arrive here after // the request has been popped by the state machine or the ChannelHandler has been // removed from the Channel pipeline. This is a normal threading issue, noone has // screwed up. + promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamPartReceived(data) + if let timeoutAction = self.idleWriteTimeoutStateMachine?.write() { + self.runTimeoutAction(timeoutAction, context: context) + } + + let action = self.state.requestStreamPartReceived(data, promise: promise) self.run(action, context: context) } - private func finishRequestBodyStream0(_ request: HTTPExecutableRequest) { + private func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return } - let action = self.state.requestStreamFinished() + let action = self.state.requestStreamFinished(promise: promise) self.run(action, context: context) } @@ -321,48 +425,51 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { return } + if let timeoutAction = self.idleWriteTimeoutStateMachine?.cancelRequest() { + self.runTimeoutAction(timeoutAction, context: context) + } + let action = self.state.requestCancelled() self.run(action, context: context) } } -extension HTTP2ClientRequestHandler: HTTPRequestExecutor { - func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.writeRequestBodyPart0(data, request: request) - } else { - self.eventLoop.execute { - self.writeRequestBodyPart0(data, request: request) +@available(*, unavailable) +extension HTTP2ClientRequestHandler: Sendable {} + +extension HTTP2ClientRequestHandler { + var requestExecutor: RequestExecutor { + RequestExecutor(self) + } + + struct RequestExecutor: HTTPRequestExecutor, Sendable { + private let loopBound: NIOLoopBound + + init(_ handler: HTTP2ClientRequestHandler) { + self.loopBound = NIOLoopBound(handler, eventLoop: handler.eventLoop) + } + + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.writeRequestBodyPart0(data, request: request, promise: promise) } } - } - func finishRequestBodyStream(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.finishRequestBodyStream0(request) - } else { - self.eventLoop.execute { - self.finishRequestBodyStream0(request) + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.finishRequestBodyStream0(request, promise: promise) } } - } - func demandResponseBodyStream(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.demandResponseBodyStream0(request) - } else { - self.eventLoop.execute { - self.demandResponseBodyStream0(request) + func demandResponseBodyStream(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.demandResponseBodyStream0(request) } } - } - func cancelRequest(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.cancelRequest0(request) - } else { - self.eventLoop.execute { - self.cancelRequest0(request) + func cancelRequest(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.cancelRequest0(request) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift index 8eb189adc..1c24554e2 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift @@ -15,12 +15,13 @@ import Logging import NIOCore import NIOHTTP2 +import NIOHTTPCompression -protocol HTTP2ConnectionDelegate { - func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) - func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) - func http2ConnectionGoAwayReceived(_: HTTP2Connection) - func http2ConnectionClosed(_: HTTP2Connection) +protocol HTTP2ConnectionDelegate: Sendable { + func http2Connection(_: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) + func http2ConnectionStreamClosed(_: HTTPConnectionPool.Connection.ID, availableStreams: Int) + func http2ConnectionGoAwayReceived(_: HTTPConnectionPool.Connection.ID) + func http2ConnectionClosed(_: HTTPConnectionPool.Connection.ID) } struct HTTP2PushNotSupportedError: Error {} @@ -28,10 +29,15 @@ struct HTTP2PushNotSupportedError: Error {} struct HTTP2ReceivedGoAwayBeforeSettingsError: Error {} final class HTTP2Connection { + internal static let defaultSettings = nioDefaultSettings + [HTTP2Setting(parameter: .enablePush, value: 0)] + let channel: Channel let multiplexer: HTTP2StreamMultiplexer let logger: Logger + /// A method with access to the stream channel that is called when creating the stream. + let streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + /// the connection pool that created the connection let delegate: HTTP2ConnectionDelegate @@ -76,25 +82,34 @@ final class HTTP2Connection { /// We use this channel set to remember, which open streams we need to inform that /// we want to close the connection. The channels shall than cancel their currently running - /// request. + /// request. This property must only be accessed from the connections `EventLoop`. private var openStreams = Set() let id: HTTPConnectionPool.Connection.ID + let decompression: HTTPClient.Decompression + let maximumConnectionUses: Int? var closeFuture: EventLoopFuture { self.channel.closeFuture } - init(channel: Channel, - connectionID: HTTPConnectionPool.Connection.ID, - delegate: HTTP2ConnectionDelegate, - logger: Logger) { + init( + channel: Channel, + connectionID: HTTPConnectionPool.Connection.ID, + decompression: HTTPClient.Decompression, + maximumConnectionUses: Int?, + delegate: HTTP2ConnectionDelegate, + logger: Logger, + streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil + ) { self.channel = channel self.id = connectionID + self.decompression = decompression + self.maximumConnectionUses = maximumConnectionUses self.logger = logger self.multiplexer = HTTP2StreamMultiplexer( mode: .client, channel: channel, - targetWindowSize: 8 * 1024 * 1024, // 8mb + targetWindowSize: 8 * 1024 * 1024, // 8mb outboundBufferSizeHighWatermark: 8196, outboundBufferSizeLowWatermark: 4092, inboundStreamInitializer: { channel -> EventLoopFuture in @@ -103,11 +118,12 @@ final class HTTP2Connection { ) self.delegate = delegate self.state = .initialized + self.streamChannelDebugInitializer = streamChannelDebugInitializer } deinit { guard case .closed = self.state else { - preconditionFailure("Connection must be closed, before we can deinit it") + preconditionFailure("Connection must be closed, before we can deinit it. Current state: \(self.state)") } } @@ -115,54 +131,93 @@ final class HTTP2Connection { channel: Channel, connectionID: HTTPConnectionPool.Connection.ID, delegate: HTTP2ConnectionDelegate, - configuration: HTTPClient.Configuration, - logger: Logger - ) -> EventLoopFuture<(HTTP2Connection, Int)> { - let connection = HTTP2Connection(channel: channel, connectionID: connectionID, delegate: delegate, logger: logger) - return connection.start().map { maxStreams in (connection, maxStreams) } + decompression: HTTPClient.Decompression, + maximumConnectionUses: Int?, + logger: Logger, + streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil + ) -> EventLoopFuture<(HTTP2Connection, Int)>.Isolated { + let connection = HTTP2Connection( + channel: channel, + connectionID: connectionID, + decompression: decompression, + maximumConnectionUses: maximumConnectionUses, + delegate: delegate, + logger: logger, + streamChannelDebugInitializer: streamChannelDebugInitializer + ) + + return connection._start0().assumeIsolated().map { maxStreams in + (connection, maxStreams) + } } - func executeRequest(_ request: HTTPExecutableRequest) { - if self.channel.eventLoop.inEventLoop { - self.executeRequest0(request) - } else { - self.channel.eventLoop.execute { - self.executeRequest0(request) + var sendableView: SendableView { + SendableView(self) + } + + struct SendableView: Sendable { + private let connection: NIOLoopBound + let id: HTTPConnectionPool.Connection.ID + let channel: Channel + + var eventLoop: EventLoop { + self.connection.eventLoop + } + + var closeFuture: EventLoopFuture { + self.channel.closeFuture + } + + func __forTesting_getStreamChannels() -> [Channel] { + self.connection.value.__forTesting_getStreamChannels() + } + + init(_ connection: HTTP2Connection) { + self.connection = NIOLoopBound(connection, eventLoop: connection.channel.eventLoop) + self.id = connection.id + self.channel = connection.channel + } + + func executeRequest(_ request: HTTPExecutableRequest) { + self.connection.execute { + $0.executeRequest0(request) } } - } - /// shuts down the connection by cancelling all running tasks and closing the connection once - /// all child streams/channels are closed. - func shutdown() { - if self.channel.eventLoop.inEventLoop { - self.shutdown0() - } else { - self.channel.eventLoop.execute { - self.shutdown0() + func shutdown() { + self.connection.execute { + $0.shutdown0() } } - } - func close(promise: EventLoopPromise?) { - return self.channel.close(mode: .all, promise: promise) - } + func close(promise: EventLoopPromise?) { + self.channel.close(mode: .all, promise: promise) + } - func close() -> EventLoopFuture { - let promise = self.channel.eventLoop.makePromise(of: Void.self) - self.close(promise: promise) - return promise.futureResult + func close() -> EventLoopFuture { + let promise = self.eventLoop.makePromise(of: Void.self) + self.close(promise: promise) + return promise.futureResult + } } - private func start() -> EventLoopFuture { + func _start0() -> EventLoopFuture { self.channel.eventLoop.assertInEventLoop() let readyToAcceptConnectionsPromise = self.channel.eventLoop.makePromise(of: Int.self) self.state = .starting(readyToAcceptConnectionsPromise) - self.channel.closeFuture.whenComplete { _ in - self.state = .closed - self.delegate.http2ConnectionClosed(self) + self.channel.closeFuture.assumeIsolated().whenComplete { _ in + switch self.state { + case .initialized, .closed: + preconditionFailure("invalid state \(self.state)") + case .starting(let readyToAcceptConnectionsPromise): + self.state = .closed + readyToAcceptConnectionsPromise.fail(HTTPClientError.remoteConnectionClosed) + case .active, .closing: + self.state = .closed + self.delegate.http2ConnectionClosed(self.id) + } } do { @@ -173,8 +228,12 @@ final class HTTP2Connection { // can be scheduled on this connection. let sync = self.channel.pipeline.syncOperations - let http2Handler = NIOHTTP2Handler(mode: .client, initialSettings: nioDefaultSettings) - let idleHandler = HTTP2IdleHandler(delegate: self, logger: self.logger) + let http2Handler = NIOHTTP2Handler(mode: .client, initialSettings: Self.defaultSettings) + let idleHandler = HTTP2IdleHandler( + delegate: self, + logger: self.logger, + maximumConnectionUses: self.maximumConnectionUses + ) try sync.addHandler(http2Handler, position: .last) try sync.addHandler(idleHandler, position: .last) @@ -196,34 +255,51 @@ final class HTTP2Connection { case .active: let createStreamChannelPromise = self.channel.eventLoop.makePromise(of: Channel.self) - self.multiplexer.createStreamChannel(promise: createStreamChannelPromise) { channel -> EventLoopFuture in + let loopBoundSelf = NIOLoopBound(self, eventLoop: self.channel.eventLoop) + + self.multiplexer.createStreamChannel( + promise: createStreamChannelPromise + ) { [streamChannelDebugInitializer] channel -> EventLoopFuture in + let connection = loopBoundSelf.value + do { // the connection may have been asked to shutdown while we created the child. in // this // channel. - guard case .active = self.state else { + guard case .active = connection.state else { throw HTTPClientError.cancelled } // We only support http/2 over an https connection – using the Application-Layer // Protocol Negotiation (ALPN). For this reason it is safe to fix this to `.https`. let translate = HTTP2FramePayloadToHTTP1ClientCodec(httpProtocol: .https) - let handler = HTTP2ClientRequestHandler(eventLoop: channel.eventLoop) - try channel.pipeline.syncOperations.addHandler(translate) + + if case .enabled(let limit) = connection.decompression { + let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) + try channel.pipeline.syncOperations.addHandler(decompressHandler) + } + + let handler = HTTP2ClientRequestHandler(eventLoop: channel.eventLoop) try channel.pipeline.syncOperations.addHandler(handler) // We must add the new channel to the list of open channels BEFORE we write the // request to it. In case of an error, we are sure that the channel was added // before. let box = ChannelBox(channel) - self.openStreams.insert(box) - self.channel.closeFuture.whenComplete { _ in - self.openStreams.remove(box) + connection.openStreams.insert(box) + channel.closeFuture.assumeIsolated().whenComplete { _ in + connection.openStreams.remove(box) } - channel.write(request, promise: nil) - return channel.eventLoop.makeSucceededVoidFuture() + if let streamChannelDebugInitializer = streamChannelDebugInitializer { + return streamChannelDebugInitializer(channel).map { _ in + channel.write(request, promise: nil) + } + } else { + channel.pipeline.syncOperations.write(NIOAny(request), promise: nil) + return channel.eventLoop.makeSucceededVoidFuture() + } } catch { return channel.eventLoop.makeFailedFuture(error) } @@ -243,16 +319,31 @@ final class HTTP2Connection { private func shutdown0() { self.channel.eventLoop.assertInEventLoop() - self.state = .closing + switch self.state { + case .active: + self.state = .closing + + // inform all open streams, that the currently running request should be cancelled. + for box in self.openStreams { + box.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) + } + + // inform the idle connection handler, that connection should be closed, once all streams + // are closed. + self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) - // inform all open streams, that the currently running request should be cancelled. - self.openStreams.forEach { box in - box.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) + case .closed, .closing: + // we are already closing/closed and we need to tolerate this + break + + case .initialized, .starting: + preconditionFailure("invalid state \(self.state)") } + } - // inform the idle connection handler, that connection should be closed, once all streams - // are closed. - self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) + func __forTesting_getStreamChannels() -> [Channel] { + self.channel.eventLoop.preconditionInEventLoop() + return self.openStreams.map { $0.channel } } } @@ -270,7 +361,7 @@ extension HTTP2Connection: HTTP2IdleHandlerDelegate { case .active: self.state = .active(maxStreams: maxStreams) - self.delegate.http2Connection(self, newMaxStreamSetting: maxStreams) + self.delegate.http2Connection(self.id, newMaxStreamSetting: maxStreams) case .closing, .closed: // ignore. we only wait for all connections to be closed anyway. @@ -291,7 +382,7 @@ extension HTTP2Connection: HTTP2IdleHandlerDelegate { case .active: self.state = .closing - self.delegate.http2ConnectionGoAwayReceived(self) + self.delegate.http2ConnectionGoAwayReceived(self.id) case .closing, .closed: // we are already closing. Nothing new @@ -302,6 +393,9 @@ extension HTTP2Connection: HTTP2IdleHandlerDelegate { func http2StreamClosed(availableStreams: Int) { self.channel.eventLoop.assertInEventLoop() - self.delegate.http2ConnectionStreamClosed(self, availableStreams: availableStreams) + self.delegate.http2ConnectionStreamClosed(self.id, availableStreams: availableStreams) } } + +@available(*, unavailable) +extension HTTP2Connection: Sendable {} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift index c522b2425..64a151489 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift @@ -35,9 +35,10 @@ final class HTTP2IdleHandler: ChannelDuplexH let logger: Logger let delegate: Delegate - private var state: StateMachine = .init() + private var state: StateMachine - init(delegate: Delegate, logger: Logger) { + init(delegate: Delegate, logger: Logger, maximumConnectionUses: Int? = nil) { + self.state = StateMachine(maximumUses: maximumConnectionUses) self.delegate = delegate self.logger = logger } @@ -140,19 +141,23 @@ extension HTTP2IdleHandler { } enum State { - case initialized - case connected - case active(openStreams: Int, maxStreams: Int) + case initialized(maximumUses: Int?) + case connected(remainingUses: Int?) + case active(openStreams: Int, maxStreams: Int, remainingUses: Int?) case closing(openStreams: Int, maxStreams: Int) case closed } - var state: State = .initialized + var state: State + + init(maximumUses: Int?) { + self.state = .initialized(maximumUses: maximumUses) + } mutating func channelActive() { switch self.state { - case .initialized: - self.state = .connected + case .initialized(let maximumUses): + self.state = .connected(remainingUses: maximumUses) case .connected, .active, .closing, .closed: break @@ -171,17 +176,23 @@ extension HTTP2IdleHandler { case .initialized: preconditionFailure("Invalid state: \(self.state)") - case .connected: + case .connected(let remainingUses): // a settings frame might have multiple entries for `maxConcurrentStreams`. We are // only interested in the last value! If no `maxConcurrentStreams` is set, we assume // the http/2 default of 100. let maxStreams = settings.last(where: { $0.parameter == .maxConcurrentStreams })?.value ?? 100 - self.state = .active(openStreams: 0, maxStreams: maxStreams) + self.state = .active(openStreams: 0, maxStreams: maxStreams, remainingUses: remainingUses) return .notifyConnectionNewMaxStreamsSettings(maxStreams) - case .active(openStreams: let openStreams, maxStreams: let maxStreams): - if let newMaxStreams = settings.last(where: { $0.parameter == .maxConcurrentStreams })?.value, newMaxStreams != maxStreams { - self.state = .active(openStreams: openStreams, maxStreams: newMaxStreams) + case .active(let openStreams, let maxStreams, let remainingUses): + if let newMaxStreams = settings.last(where: { $0.parameter == .maxConcurrentStreams })?.value, + newMaxStreams != maxStreams + { + self.state = .active( + openStreams: openStreams, + maxStreams: newMaxStreams, + remainingUses: remainingUses + ) return .notifyConnectionNewMaxStreamsSettings(newMaxStreams) } return .nothing @@ -205,7 +216,7 @@ extension HTTP2IdleHandler { self.state = .closing(openStreams: 0, maxStreams: 0) return .notifyConnectionGoAwayReceived(close: true) - case .active(let openStreams, let maxStreams): + case .active(let openStreams, let maxStreams, _): self.state = .closing(openStreams: openStreams, maxStreams: maxStreams) return .notifyConnectionGoAwayReceived(close: openStreams == 0) @@ -228,7 +239,7 @@ extension HTTP2IdleHandler { self.state = .closing(openStreams: 0, maxStreams: 0) return .close - case .active(let openStreams, let maxStreams): + case .active(let openStreams, let maxStreams, _): if openStreams == 0 { self.state = .closed return .close @@ -247,10 +258,19 @@ extension HTTP2IdleHandler { case .initialized, .connected: preconditionFailure("Invalid state: \(self.state)") - case .active(var openStreams, let maxStreams): + case .active(var openStreams, let maxStreams, let remainingUses): openStreams += 1 - self.state = .active(openStreams: openStreams, maxStreams: maxStreams) - return .nothing + let remainingUses = remainingUses.map { $0 - 1 } + self.state = .active(openStreams: openStreams, maxStreams: maxStreams, remainingUses: remainingUses) + + if remainingUses == 0 { + // Treat running out of connection uses as if we received a GOAWAY frame. This + // will notify the delegate (i.e. connection pool) that the connection can no + // longer be used. + return self.goAwayReceived() + } else { + return .nothing + } case .closing(var openStreams, let maxStreams): // A stream might be opened, while we are closing because of race conditions. For @@ -271,10 +291,10 @@ extension HTTP2IdleHandler { case .initialized, .connected: preconditionFailure("Invalid state: \(self.state)") - case .active(var openStreams, let maxStreams): + case .active(var openStreams, let maxStreams, let remainingUses): openStreams -= 1 assert(openStreams >= 0) - self.state = .active(openStreams: openStreams, maxStreams: maxStreams) + self.state = .active(openStreams: openStreams, maxStreams: maxStreams, remainingUses: remainingUses) return .notifyConnectionStreamClosed(currentlyAvailable: maxStreams - openStreams) case .closing(var openStreams, let maxStreams): diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift index 4a3338697..3dc47c5ae 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift @@ -20,7 +20,9 @@ import NIOPosix import NIOSOCKS import NIOSSL import NIOTLS + #if canImport(Network) +import Network import NIOTransportServices #endif @@ -31,22 +33,26 @@ extension HTTPConnectionPool { let tlsConfiguration: TLSConfiguration let sslContextCache: SSLContextCache - init(key: ConnectionPool.Key, - tlsConfiguration: TLSConfiguration?, - clientConfiguration: HTTPClient.Configuration, - sslContextCache: SSLContextCache) { + init( + key: ConnectionPool.Key, + tlsConfiguration: TLSConfiguration?, + clientConfiguration: HTTPClient.Configuration, + sslContextCache: SSLContextCache + ) { self.key = key self.clientConfiguration = clientConfiguration self.sslContextCache = sslContextCache - self.tlsConfiguration = tlsConfiguration ?? clientConfiguration.tlsConfiguration ?? .makeClientConfiguration() + self.tlsConfiguration = + tlsConfiguration ?? clientConfiguration.tlsConfiguration ?? .makeClientConfiguration() } } } -protocol HTTPConnectionRequester { - func http1ConnectionCreated(_: HTTP1Connection) - func http2ConnectionCreated(_: HTTP2Connection, maximumStreams: Int) +protocol HTTPConnectionRequester: Sendable { + func http1ConnectionCreated(_: HTTP1Connection.SendableView) + func http2ConnectionCreated(_: HTTP2Connection.SendableView, maximumStreams: Int) func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Error) + func waitingForConnectivity(_: HTTPConnectionPool.Connection.ID, error: Error) } extension HTTPConnectionPool.ConnectionFactory { @@ -62,7 +68,8 @@ extension HTTPConnectionPool.ConnectionFactory { var logger = logger logger[metadataKey: "ahc-connection-id"] = "\(connectionID)" - self.makeChannel(connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, logger: logger).whenComplete { result in + let promise = eventLoop.makePromise(of: NegotiatedProtocol.self) + promise.futureResult.whenComplete { [logger] result in switch result { case .success(.http1_1(let channel)): do { @@ -70,10 +77,24 @@ extension HTTPConnectionPool.ConnectionFactory { channel: channel, connectionID: connectionID, delegate: http1ConnectionDelegate, - configuration: self.clientConfiguration, + decompression: self.clientConfiguration.decompression, logger: logger ) - requester.http1ConnectionCreated(connection) + + if let connectionDebugInitializer = self.clientConfiguration.http1_1ConnectionDebugInitializer { + connectionDebugInitializer(channel).hop( + to: eventLoop + ).assumeIsolated().whenComplete { debugInitializerResult in + switch debugInitializerResult { + case .success: + requester.http1ConnectionCreated(connection.sendableView) + case .failure(let error): + requester.failedToCreateHTTPConnection(connectionID, error: error) + } + } + } else { + requester.http1ConnectionCreated(connection.sendableView) + } } catch { requester.failedToCreateHTTPConnection(connectionID, error: error) } @@ -82,198 +103,256 @@ extension HTTPConnectionPool.ConnectionFactory { channel: channel, connectionID: connectionID, delegate: http2ConnectionDelegate, - configuration: self.clientConfiguration, - logger: logger + decompression: self.clientConfiguration.decompression, + maximumConnectionUses: self.clientConfiguration.maximumUsesPerConnection, + logger: logger, + streamChannelDebugInitializer: + self.clientConfiguration.http2StreamChannelDebugInitializer ).whenComplete { result in switch result { case .success((let connection, let maximumStreams)): - requester.http2ConnectionCreated(connection, maximumStreams: maximumStreams) + if let connectionDebugInitializer = self.clientConfiguration.http2ConnectionDebugInitializer { + connectionDebugInitializer(channel).hop(to: eventLoop).assumeIsolated().whenComplete { + debugInitializerResult in + switch debugInitializerResult { + case .success: + requester.http2ConnectionCreated( + connection.sendableView, + maximumStreams: maximumStreams + ) + case .failure(let error): + requester.failedToCreateHTTPConnection( + connectionID, + error: error + ) + } + } + } else { + requester.http2ConnectionCreated( + connection.sendableView, + maximumStreams: maximumStreams + ) + } case .failure(let error): requester.failedToCreateHTTPConnection(connectionID, error: error) } } - case .failure(let error): + case .failure(var error): + // let's map `ChannelError.connectTimeout` into a `HTTPClientError.connectTimeout` + switch error { + case ChannelError.connectTimeout: + error = HTTPClientError.connectTimeout + default: + () + } requester.failedToCreateHTTPConnection(connectionID, error: error) } } - } - enum NegotiatedProtocol { - case http1_1(Channel) - case http2(Channel) - } - - func makeHTTP1Channel( - connectionID: HTTPConnectionPool.Connection.ID, - deadline: NIODeadline, - eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture { self.makeChannel( + requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, - logger: logger - ).flatMapThrowing { negotiated -> Channel in - - guard case .http1_1(let channel) = negotiated else { - preconditionFailure("Expected to create http/1.1 connections only for now") - } - - // add the http1.1 channel handlers - let syncOperations = channel.pipeline.syncOperations - try syncOperations.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes) - - switch self.clientConfiguration.decompression { - case .disabled: - () - case .enabled(let limit): - let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) - try syncOperations.addHandler(decompressHandler) - } + logger: logger, + promise: promise + ) + } - return channel - } + enum NegotiatedProtocol { + case http1_1(Channel) + case http2(Channel) } - func makeChannel( + func makeChannel( + requester: Requester, connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture { - let channelFuture: EventLoopFuture - + logger: Logger, + promise: EventLoopPromise + ) { if self.key.scheme.isProxyable, let proxy = self.clientConfiguration.proxy { switch proxy.type { case .socks: - channelFuture = self.makeSOCKSProxyChannel( + self.makeSOCKSProxyChannel( proxy, + requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, - logger: logger + logger: logger, + promise: promise ) case .http: - channelFuture = self.makeHTTPProxyChannel( + self.makeHTTPProxyChannel( proxy, + requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, - logger: logger + logger: logger, + promise: promise ) } } else { - channelFuture = self.makeNonProxiedChannel(deadline: deadline, eventLoop: eventLoop, logger: logger) - } - - // let's map `ChannelError.connectTimeout` into a `HTTPClientError.connectTimeout` - return channelFuture.flatMapErrorThrowing { error throws -> NegotiatedProtocol in - switch error { - case ChannelError.connectTimeout: - throw HTTPClientError.connectTimeout - default: - throw error - } + self.makeNonProxiedChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger, + promise: promise + ) } } - private func makeNonProxiedChannel( + private func makeNonProxiedChannel( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture { + logger: Logger, + promise: EventLoopPromise + ) { switch self.key.scheme { case .http, .httpUnix, .unix: - return self.makePlainChannel(deadline: deadline, eventLoop: eventLoop).map { .http1_1($0) } + self.makePlainChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + promise: promise + ) case .https, .httpsUnix: - return self.makeTLSChannel(deadline: deadline, eventLoop: eventLoop, logger: logger).flatMapThrowing { - channel, negotiated in - - try self.matchALPNToHTTPVersion(negotiated, channel: channel) - } + self.makeTLSChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger, + promise: promise + ) } } - private func makePlainChannel(deadline: NIODeadline, eventLoop: EventLoop) -> EventLoopFuture { + private func makePlainChannel( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + deadline: NIODeadline, + eventLoop: EventLoop, + promise: EventLoopPromise + ) { precondition(!self.key.scheme.usesTLS, "Unexpected scheme") - return self.makePlainBootstrap(deadline: deadline, eventLoop: eventLoop).connect(target: self.key.connectionTarget) + return self.makePlainBootstrap( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ).connect(target: self.key.connectionTarget).map { + .http1_1($0) + }.cascade(to: promise) } - private func makeHTTPProxyChannel( + private func makeHTTPProxyChannel( _ proxy: HTTPClient.Configuration.Proxy, + requester: Requester, connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture { + logger: Logger, + promise: EventLoopPromise + ) { // A proxy connection starts with a plain text connection to the proxy server. After // the connection has been established with the proxy server, the connection might be // upgraded to TLS before we send our first request. - let bootstrap = self.makePlainBootstrap(deadline: deadline, eventLoop: eventLoop) - return bootstrap.connect(host: proxy.host, port: proxy.port).flatMap { channel in - let encoder = HTTPRequestEncoder() - let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)) - let proxyHandler = HTTP1ProxyConnectHandler( - target: self.key.connectionTarget, - proxyAuthorization: proxy.authorization, - deadline: deadline - ) - - do { - try channel.pipeline.syncOperations.addHandler(encoder) - try channel.pipeline.syncOperations.addHandler(decoder) - try channel.pipeline.syncOperations.addHandler(proxyHandler) - } catch { - return channel.eventLoop.makeFailedFuture(error) - } + let bootstrap = self.makePlainBootstrap( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ) + bootstrap.connect(host: proxy.host, port: proxy.port).whenComplete { result in + switch result { + case .success(let channel): + let encoder = HTTPRequestEncoder() + let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)) + let proxyHandler = HTTP1ProxyConnectHandler( + target: self.key.connectionTarget, + proxyAuthorization: proxy.authorization, + deadline: deadline + ) - // The proxyEstablishedFuture is set as soon as the HTTP1ProxyConnectHandler is in a - // pipeline. It is created in HTTP1ProxyConnectHandler's handlerAdded method. - return proxyHandler.proxyEstablishedFuture!.flatMap { - channel.pipeline.removeHandler(proxyHandler).flatMap { - channel.pipeline.removeHandler(decoder).flatMap { - channel.pipeline.removeHandler(encoder) - } + do { + try channel.pipeline.syncOperations.addHandler(encoder) + try channel.pipeline.syncOperations.addHandler(decoder) + try channel.pipeline.syncOperations.addHandler(proxyHandler) + } catch { + return promise.fail(error) } - }.flatMap { - self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) + + // The proxyEstablishedFuture is set as soon as the HTTP1ProxyConnectHandler is in a + // pipeline. It is created in HTTP1ProxyConnectHandler's handlerAdded method. + return proxyHandler.proxyEstablishedFuture!.assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(proxyHandler).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(decoder).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(encoder) + }.nonisolated() + }.nonisolated() + }.flatMap { + self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) + }.nonisolated().cascade(to: promise) + case .failure(let error): + promise.fail(error) } } } - private func makeSOCKSProxyChannel( + private func makeSOCKSProxyChannel( _ proxy: HTTPClient.Configuration.Proxy, + requester: Requester, connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture { + logger: Logger, + promise: EventLoopPromise + ) { // A proxy connection starts with a plain text connection to the proxy server. After // the connection has been established with the proxy server, the connection might be // upgraded to TLS before we send our first request. - let bootstrap = self.makePlainBootstrap(deadline: deadline, eventLoop: eventLoop) - return bootstrap.connect(host: proxy.host, port: proxy.port).flatMap { channel in - let socksConnectHandler = SOCKSClientHandler(targetAddress: SOCKSAddress(self.key.connectionTarget)) - let socksEventHandler = SOCKSEventsHandler(deadline: deadline) - - do { - try channel.pipeline.syncOperations.addHandler(socksConnectHandler) - try channel.pipeline.syncOperations.addHandler(socksEventHandler) - } catch { - return channel.eventLoop.makeFailedFuture(error) - } + let bootstrap = self.makePlainBootstrap( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ) + bootstrap.connect(host: proxy.host, port: proxy.port).whenComplete { result in + switch result { + case .success(let channel): + let socksConnectHandler = SOCKSClientHandler(targetAddress: SOCKSAddress(self.key.connectionTarget)) + let socksEventHandler = SOCKSEventsHandler(deadline: deadline) - // The socksEstablishedFuture is set as soon as the SOCKSEventsHandler is in a - // pipeline. It is created in SOCKSEventsHandler's handlerAdded method. - return socksEventHandler.socksEstablishedFuture!.flatMap { - channel.pipeline.removeHandler(socksEventHandler).flatMap { - channel.pipeline.removeHandler(socksConnectHandler) + do { + try channel.pipeline.syncOperations.addHandler(socksConnectHandler) + try channel.pipeline.syncOperations.addHandler(socksEventHandler) + } catch { + return promise.fail(error) } - }.flatMap { - self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) + + // The socksEstablishedFuture is set as soon as the SOCKSEventsHandler is in a + // pipeline. It is created in SOCKSEventsHandler's handlerAdded method. + socksEventHandler.socksEstablishedFuture!.assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(socksEventHandler).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(socksConnectHandler) + }.nonisolated() + }.flatMap { + self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) + }.nonisolated().cascade(to: promise) + case .failure(let error): + promise.fail(error) } + } } @@ -299,9 +378,8 @@ extension HTTPConnectionPool.ConnectionFactory { case .http1Only: tlsConfig.applicationProtocols = ["http/1.1"] } - let tlsEventHandler = TLSEventsHandler(deadline: deadline) - let sslServerHostname = self.key.connectionTarget.sslServerHostname + let sslServerHostname = self.key.serverNameIndicator let sslContextFuture = self.sslContextCache.sslContext( tlsConfiguration: tlsConfig, eventLoop: channel.eventLoop, @@ -315,6 +393,7 @@ extension HTTPConnectionPool.ConnectionFactory { serverHostname: sslServerHostname ) try channel.pipeline.syncOperations.addHandler(sslHandler) + let tlsEventHandler = TLSEventsHandler(deadline: deadline) try channel.pipeline.syncOperations.addHandler(tlsEventHandler) // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a @@ -324,21 +403,46 @@ extension HTTPConnectionPool.ConnectionFactory { return channel.eventLoop.makeFailedFuture(error) } }.flatMap { negotiated -> EventLoopFuture in - channel.pipeline.removeHandler(tlsEventHandler).flatMapThrowing { - try self.matchALPNToHTTPVersion(negotiated, channel: channel) + do { + let sync = channel.pipeline.syncOperations + let context = try sync.context(handlerType: TLSEventsHandler.self) + return sync.removeHandler(context: context).flatMapThrowing { + try Self.matchALPNToHTTPVersion(negotiated, channel: channel) + } + } catch { + return channel.eventLoop.makeFailedFuture(error) } } } } - private func makePlainBootstrap(deadline: NIODeadline, eventLoop: EventLoop) -> NIOClientTCPBootstrapProtocol { + private func makePlainBootstrap( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + deadline: NIODeadline, + eventLoop: EventLoop + ) -> NIOClientTCPBootstrapProtocol { #if canImport(Network) - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { - return tsBootstrap + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), + let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) + { + return + tsBootstrap + .channelOption( + NIOTSChannelOptions.waitForActivity, + value: self.clientConfiguration.networkFrameworkWaitForConnectivity + ) + .channelOption( + NIOTSChannelOptions.multipathServiceType, + value: self.clientConfiguration.enableMultipath ? .handover : .disabled + ) .connectTimeout(deadline - NIODeadline.now()) .channelInitializer { channel in do { try channel.pipeline.syncOperations.addHandler(HTTPClient.NWErrorHandler()) + try channel.pipeline.syncOperations.addHandler( + NWWaitingHandler(requester: requester, connectionID: connectionID) + ) return channel.eventLoop.makeSucceededVoidFuture() } catch { return channel.eventLoop.makeFailedFuture(error) @@ -348,47 +452,77 @@ extension HTTPConnectionPool.ConnectionFactory { #endif if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { - return nioBootstrap + return + nioBootstrap .connectTimeout(deadline - NIODeadline.now()) + .enableMPTCP(clientConfiguration.enableMultipath) } preconditionFailure("No matching bootstrap found") } - private func makeTLSChannel(deadline: NIODeadline, eventLoop: EventLoop, logger: Logger) -> EventLoopFuture<(Channel, String?)> { + private func makeTLSChannel( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + deadline: NIODeadline, + eventLoop: EventLoop, + logger: Logger, + promise: EventLoopPromise + ) { precondition(self.key.scheme.usesTLS, "Unexpected scheme") let bootstrapFuture = self.makeTLSBootstrap( + requester: requester, + connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, logger: logger ) - var channelFuture = bootstrapFuture.flatMap { bootstrap -> EventLoopFuture in - return bootstrap.connect(target: self.key.connectionTarget) - }.flatMap { channel -> EventLoopFuture<(Channel, String?)> in - // It is save to use `try!` here, since we are sure, that a `TLSEventsHandler` exists - // within the pipeline. It is added in `makeTLSBootstrap`. - let tlsEventHandler = try! channel.pipeline.syncOperations.handler(type: TLSEventsHandler.self) - - // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a - // pipeline. It is created in TLSEventsHandler's handlerAdded method. - return tlsEventHandler.tlsEstablishedFuture!.flatMap { negotiated in - channel.pipeline.removeHandler(tlsEventHandler).map { (channel, negotiated) } + bootstrapFuture.whenComplete { result in + switch result { + case .success(let bootstrap): + bootstrap.connect(target: self.key.connectionTarget).flatMap { + channel -> EventLoopFuture<(Channel, String?)> in + do { + // if the channel is closed before flatMap is executed, all ChannelHandler are removed + // and TLSEventsHandler is therefore not present either + let tlsEventHandler = try channel.pipeline.syncOperations.handler(type: TLSEventsHandler.self) + + // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a + // pipeline. It is created in TLSEventsHandler's handlerAdded method. + return tlsEventHandler.tlsEstablishedFuture!.assumeIsolated().flatMap { negotiated in + channel.pipeline.syncOperations.removeHandler(tlsEventHandler).map { (channel, negotiated) } + }.nonisolated() + } catch { + assert( + channel.isActive == false, + "if the channel is still active then TLSEventsHandler must be present but got error \(error)" + ) + return channel.eventLoop.makeFailedFuture(HTTPClientError.remoteConnectionClosed) + } + }.flatMapThrowing { channel, alpn in + try Self.matchALPNToHTTPVersion(alpn, channel: channel) + }.flatMapErrorThrowing { error in + // If NIOTransportSecurity is used, we want to map NWErrors into NWPOsixErrors or NWTLSError. + #if canImport(Network) + throw HTTPClient.NWErrorHandler.translateError(error) + #else + throw error + #endif + }.cascade(to: promise) + case .failure(let error): + promise.fail(error) } } - - #if canImport(Network) - // If NIOTransportSecurity is used, we want to map NWErrors into NWPOsixErrors or NWTLSError. - channelFuture = channelFuture.flatMapErrorThrowing { error in - throw HTTPClient.NWErrorHandler.translateError(error) - } - #endif - - return channelFuture } - private func makeTLSBootstrap(deadline: NIODeadline, eventLoop: EventLoop, logger: Logger) - -> EventLoopFuture { + private func makeTLSBootstrap( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + deadline: NIODeadline, + eventLoop: EventLoop, + logger: Logger + ) -> EventLoopFuture { var tlsConfig = self.tlsConfiguration switch self.clientConfiguration.httpVersion.configuration { case .automatic: @@ -402,17 +536,31 @@ extension HTTPConnectionPool.ConnectionFactory { } #if canImport(Network) - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), eventLoop is QoSEventLoop { // create NIOClientTCPBootstrap with NIOTS TLS provider - let bootstrapFuture = tlsConfig.getNWProtocolTLSOptions(on: eventLoop).map { + let bootstrapFuture = tlsConfig.getNWProtocolTLSOptions( + on: eventLoop, + serverNameIndicatorOverride: key.serverNameIndicatorOverride + ).map { options -> NIOClientTCPBootstrapProtocol in - tsBootstrap + NIOTSConnectionBootstrap(group: eventLoop) // validated above + .channelOption( + NIOTSChannelOptions.waitForActivity, + value: self.clientConfiguration.networkFrameworkWaitForConnectivity + ) + .channelOption( + NIOTSChannelOptions.multipathServiceType, + value: self.clientConfiguration.enableMultipath ? .handover : .disabled + ) .connectTimeout(deadline - NIODeadline.now()) .tlsOptions(options) .channelInitializer { channel in do { try channel.pipeline.syncOperations.addHandler(HTTPClient.NWErrorHandler()) + try channel.pipeline.syncOperations.addHandler( + NWWaitingHandler(requester: requester, connectionID: connectionID) + ) // we don't need to set a TLS deadline for NIOTS connections, since the // TLS handshake is part of the TS connection bootstrap. If the TLS // handshake times out the complete connection creation will be failed. @@ -427,38 +575,38 @@ extension HTTPConnectionPool.ConnectionFactory { } #endif - let sslServerHostname = self.key.connectionTarget.sslServerHostname let sslContextFuture = sslContextCache.sslContext( tlsConfiguration: tlsConfig, eventLoop: eventLoop, logger: logger ) - let bootstrap = ClientBootstrap(group: eventLoop) - .connectTimeout(deadline - NIODeadline.now()) - .channelInitializer { channel in - sslContextFuture.flatMap { sslContext -> EventLoopFuture in - do { - let sync = channel.pipeline.syncOperations - let sslHandler = try NIOSSLClientHandler( - context: sslContext, - serverHostname: sslServerHostname - ) - let tlsEventHandler = TLSEventsHandler(deadline: deadline) - - try sync.addHandler(sslHandler) - try sync.addHandler(tlsEventHandler) - return channel.eventLoop.makeSucceededVoidFuture() - } catch { - return channel.eventLoop.makeFailedFuture(error) + return eventLoop.submit { + ClientBootstrap(group: eventLoop) + .connectTimeout(deadline - NIODeadline.now()) + .enableMPTCP(clientConfiguration.enableMultipath) + .channelInitializer { channel in + sslContextFuture.flatMap { sslContext -> EventLoopFuture in + do { + let sync = channel.pipeline.syncOperations + let sslHandler = try NIOSSLClientHandler( + context: sslContext, + serverHostname: self.key.serverNameIndicator + ) + let tlsEventHandler = TLSEventsHandler(deadline: deadline) + + try sync.addHandler(sslHandler) + try sync.addHandler(tlsEventHandler) + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) + } } } - } - - return eventLoop.makeSucceededFuture(bootstrap) + } } - private func matchALPNToHTTPVersion(_ negotiated: String?, channel: Channel) throws -> NegotiatedProtocol { + private static func matchALPNToHTTPVersion(_ negotiated: String?, channel: Channel) throws -> NegotiatedProtocol { switch negotiated { case .none, .some("http/1.1"): return .http1_1(channel) @@ -481,6 +629,12 @@ extension Scheme { } } +extension ConnectionPool.Key { + var serverNameIndicator: String? { + serverNameIndicatorOverride ?? connectionTarget.sslServerHostname + } +} + extension ConnectionTarget { fileprivate var sslServerHostname: String? { switch self { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift index 1a1760908..4c313e92b 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift @@ -12,35 +12,40 @@ // //===----------------------------------------------------------------------===// +import Atomics import Logging import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 extension HTTPConnectionPool { - final class Manager { + final class Manager: Sendable { private typealias Key = ConnectionPool.Key - private enum State { + private enum RunState: Sendable { case active case shuttingDown(promise: EventLoopPromise?, unclean: Bool) case shutDown } + private struct State: Sendable { + var runState: RunState = .active + var pools: [Key: HTTPConnectionPool] = [:] + } + private let eventLoopGroup: EventLoopGroup private let configuration: HTTPClient.Configuration private let connectionIDGenerator = Connection.ID.globalGenerator private let logger: Logger - private var state: State = .active - private var _pools: [Key: HTTPConnectionPool] = [:] - private let lock = Lock() - + private let state: NIOLockedValueBox = NIOLockedValueBox(State()) private let sslContextCache = SSLContextCache() - init(eventLoopGroup: EventLoopGroup, - configuration: HTTPClient.Configuration, - backgroundActivityLogger logger: Logger) { + init( + eventLoopGroup: EventLoopGroup, + configuration: HTTPClient.Configuration, + backgroundActivityLogger logger: Logger + ) { self.eventLoopGroup = eventLoopGroup self.configuration = configuration self.logger = logger @@ -48,10 +53,10 @@ extension HTTPConnectionPool { func executeRequest(_ request: HTTPSchedulableRequest) { let poolKey = request.poolKey - let poolResult = self.lock.withLock { () -> Result in - switch self.state { + let poolResult = self.state.withLockedValue { state -> Result in + switch state.runState { case .active: - if let pool = self._pools[poolKey] { + if let pool = state.pools[poolKey] { return .success(pool) } @@ -65,7 +70,7 @@ extension HTTPConnectionPool { idGenerator: self.connectionIDGenerator, backgroundActivityLogger: self.logger ) - self._pools[poolKey] = pool + state.pools[poolKey] = pool return .success(pool) case .shuttingDown, .shutDown: @@ -92,17 +97,17 @@ extension HTTPConnectionPool { case shutdown([Key: HTTPConnectionPool]) } - let action = self.lock.withLock { () -> ShutdownAction in - switch self.state { + let action = self.state.withLockedValue { state -> ShutdownAction in + switch state.runState { case .active: // If there aren't any pools, we can mark the pool as shut down right away. - if self._pools.isEmpty { - self.state = .shutDown + if state.pools.isEmpty { + state.runState = .shutDown return .done(promise) } else { // this promise will be succeeded once all connection pools are shutdown - self.state = .shuttingDown(promise: promise, unclean: false) - return .shutdown(self._pools) + state.runState = .shuttingDown(promise: promise, unclean: false) + return .shutdown(state.pools) } case .shuttingDown, .shutDown: @@ -117,7 +122,7 @@ extension HTTPConnectionPool { promise?.succeed(false) case .shutdown(let pools): - pools.values.forEach { pool in + for pool in pools.values { pool.shutdown() } } @@ -132,28 +137,30 @@ extension HTTPConnectionPool.Manager: HTTPConnectionPoolDelegate { case wait } - let closeAction = self.lock.withLock { () -> CloseAction in - switch self.state { + let closeAction = self.state.withLockedValue { state -> CloseAction in + switch state.runState { case .active, .shutDown: preconditionFailure("Why are pools shutting down, if the manager did not give a signal") case .shuttingDown(let promise, let soFarUnclean): - guard self._pools.removeValue(forKey: pool.key) === pool else { - preconditionFailure("Expected that the pool was created by this manager and is known for this reason.") + guard state.pools.removeValue(forKey: pool.key) === pool else { + preconditionFailure( + "Expected that the pool was created by this manager and is known for this reason." + ) } - if self._pools.isEmpty { - self.state = .shutDown + if state.pools.isEmpty { + state.runState = .shutDown return .close(promise, unclean: soFarUnclean || unclean) } else { - self.state = .shuttingDown(promise: promise, unclean: soFarUnclean || unclean) + state.runState = .shuttingDown(promise: promise, unclean: soFarUnclean || unclean) return .wait } } } switch closeAction { - case .close(let promise, unclean: let unclean): + case .close(let promise, let unclean): promise?.succeed(unclean) case .wait: break @@ -162,17 +169,17 @@ extension HTTPConnectionPool.Manager: HTTPConnectionPoolDelegate { } extension HTTPConnectionPool.Connection.ID { - static var globalGenerator = Generator() + static let globalGenerator = Generator() struct Generator { - private let atomic: NIOAtomic + private let atomic: ManagedAtomic init() { - self.atomic = .makeAtomic(value: 0) + self.atomic = .init(0) } func next() -> Int { - return self.atomic.add(1) + self.atomic.loadThenWrappingIncrement(ordering: .relaxed) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift index 764ad2093..676df915a 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift @@ -21,8 +21,11 @@ protocol HTTPConnectionPoolDelegate { func connectionPoolDidShutdown(_ pool: HTTPConnectionPool, unclean: Bool) } -final class HTTPConnectionPool { - private let stateLock = Lock() +final class HTTPConnectionPool: + // TODO: Refactor to use `NIOLockedValueBox` which will allow this to be checked + @unchecked Sendable +{ + private let stateLock = NIOLock() private var _state: StateMachine /// The connection idle timeout timers. Protected by the stateLock private var _idleTimer = [Connection.ID: Scheduled]() @@ -44,14 +47,16 @@ final class HTTPConnectionPool { let delegate: HTTPConnectionPoolDelegate - init(eventLoopGroup: EventLoopGroup, - sslContextCache: SSLContextCache, - tlsConfiguration: TLSConfiguration?, - clientConfiguration: HTTPClient.Configuration, - key: ConnectionPool.Key, - delegate: HTTPConnectionPoolDelegate, - idGenerator: Connection.ID.Generator, - backgroundActivityLogger logger: Logger) { + init( + eventLoopGroup: EventLoopGroup, + sslContextCache: SSLContextCache, + tlsConfiguration: TLSConfiguration?, + clientConfiguration: HTTPClient.Configuration, + key: ConnectionPool.Key, + delegate: HTTPConnectionPoolDelegate, + idGenerator: Connection.ID.Generator, + backgroundActivityLogger logger: Logger + ) { self.eventLoopGroup = eventLoopGroup self.connectionFactory = ConnectionFactory( key: key, @@ -70,7 +75,12 @@ final class HTTPConnectionPool { self._state = StateMachine( idGenerator: idGenerator, - maximumConcurrentHTTP1Connections: clientConfiguration.connectionPool.concurrentHTTP1ConnectionsPerHostSoftLimit + maximumConcurrentHTTP1Connections: clientConfiguration.connectionPool + .concurrentHTTP1ConnectionsPerHostSoftLimit, + retryConnectionEstablishment: clientConfiguration.connectionPool.retryConnectionEstablishment, + preferHTTP1: clientConfiguration.httpVersion == .http1Only, + maximumConnectionUses: clientConfiguration.maximumUsesPerConnection, + preWarmedHTTP1ConnectionCount: clientConfiguration.connectionPool.preWarmedHTTP1ConnectionCount ) } @@ -95,6 +105,11 @@ final class HTTPConnectionPool { enum Unlocked { case createConnection(Connection.ID, on: EventLoop) case closeConnection(Connection, isShutdown: StateMachine.ConnectionAction.IsShutdown) + case closeConnectionAndCreateConnection( + close: Connection, + newConnectionID: Connection.ID, + on: EventLoop + ) case cleanupConnections(CleanupContext, isShutdown: StateMachine.ConnectionAction.IsShutdown) case migration( createConnections: [(Connection.ID, EventLoop)], @@ -147,9 +162,7 @@ final class HTTPConnectionPool { self.unlocked = Unlocked(connection: .none, request: .none) switch stateMachineAction.request { - case .cancelRequestTimeout(let requestID): - self.locked.request = .cancelRequestTimeout(requestID) - case .executeRequest(let request, let connection, cancelTimeout: let cancelTimeout): + case .executeRequest(let request, let connection, let cancelTimeout): if cancelTimeout { self.locked.request = .cancelRequestTimeout(request.id) } @@ -157,7 +170,7 @@ final class HTTPConnectionPool { case .executeRequestsAndCancelTimeouts(let requests, let connection): self.locked.request = .cancelRequestTimeouts(requests) self.unlocked.request = .executeRequests(requests, connection) - case .failRequest(let request, let error, cancelTimeout: let cancelTimeout): + case .failRequest(let request, let error, let cancelTimeout): if cancelTimeout { self.locked.request = .cancelRequestTimeout(request.id) } @@ -174,16 +187,31 @@ final class HTTPConnectionPool { switch stateMachineAction.connection { case .createConnection(let connectionID, on: let eventLoop): self.unlocked.connection = .createConnection(connectionID, on: eventLoop) - case .scheduleBackoffTimer(let connectionID, backoff: let backoff, on: let eventLoop): + case .scheduleBackoffTimer(let connectionID, let backoff, on: let eventLoop): self.locked.connection = .scheduleBackoffTimer(connectionID, backoff: backoff, on: eventLoop) case .scheduleTimeoutTimer(let connectionID, on: let eventLoop): self.locked.connection = .scheduleTimeoutTimer(connectionID, on: eventLoop) + case .scheduleTimeoutTimerAndCreateConnection(let timeoutID, let newConnectionID, let eventLoop): + self.locked.connection = .scheduleTimeoutTimer(timeoutID, on: eventLoop) + self.unlocked.connection = .createConnection(newConnectionID, on: eventLoop) case .cancelTimeoutTimer(let connectionID): self.locked.connection = .cancelTimeoutTimer(connectionID) - case .closeConnection(let connection, isShutdown: let isShutdown): + case .createConnectionAndCancelTimeoutTimer(let createdID, on: let eventLoop, cancelTimerID: let cancelID): + self.unlocked.connection = .createConnection(createdID, on: eventLoop) + self.locked.connection = .cancelTimeoutTimer(cancelID) + case .closeConnection(let connection, let isShutdown): self.unlocked.connection = .closeConnection(connection, isShutdown: isShutdown) - case .cleanupConnections(var cleanupContext, isShutdown: let isShutdown): - // + case .closeConnectionAndCreateConnection( + let closeConnection, + let newConnectionID, + let eventLoop + ): + self.unlocked.connection = .closeConnectionAndCreateConnection( + close: closeConnection, + newConnectionID: newConnectionID, + on: eventLoop + ) + case .cleanupConnections(var cleanupContext, let isShutdown): self.locked.connection = .cancelBackoffTimers(cleanupContext.connectBackoff) cleanupContext.connectBackoff = [] self.unlocked.connection = .cleanupConnections(cleanupContext, isShutdown: isShutdown) @@ -220,7 +248,7 @@ final class HTTPConnectionPool { private func runLockedConnectionAction(_ action: Actions.ConnectionAction.Locked) { switch action { - case .scheduleBackoffTimer(let connectionID, backoff: let backoff, on: let eventLoop): + case .scheduleBackoffTimer(let connectionID, let backoff, on: let eventLoop): self.scheduleConnectionStartBackoffTimer(connectionID, backoff, on: eventLoop) case .scheduleTimeoutTimer(let connectionID, on: let eventLoop): @@ -248,7 +276,7 @@ final class HTTPConnectionPool { self.cancelRequestTimeout(requestID) case .cancelRequestTimeouts(let requests): - requests.forEach { self.cancelRequestTimeout($0.id) } + for request in requests { self.cancelRequestTimeout(request.id) } case .none: break @@ -265,10 +293,13 @@ final class HTTPConnectionPool { case .createConnection(let connectionID, let eventLoop): self.createConnection(connectionID, on: eventLoop) - case .closeConnection(let connection, isShutdown: let isShutdown): - self.logger.trace("close connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - ]) + case .closeConnection(let connection, let isShutdown): + self.logger.trace( + "close connection", + metadata: [ + "ahc-connection-id": "\(connection.id)" + ] + ) // we are not interested in the close promise... connection.close(promise: nil) @@ -277,7 +308,24 @@ final class HTTPConnectionPool { self.delegate.connectionPoolDidShutdown(self, unclean: unclean) } - case .cleanupConnections(let cleanupContext, isShutdown: let isShutdown): + case .closeConnectionAndCreateConnection( + let connectionToClose, + let newConnectionID, + let eventLoop + ): + self.logger.trace( + "closing and creating connection", + metadata: [ + "ahc-connection-id": "\(connectionToClose.id)" + ] + ) + + self.createConnection(newConnectionID, on: eventLoop) + + // we are not interested in the close promise... + connectionToClose.close(promise: nil) + + case .cleanupConnections(let cleanupContext, let isShutdown): for connection in cleanupContext.close { connection.close(promise: nil) } @@ -314,13 +362,15 @@ final class HTTPConnectionPool { connection.executeRequest(request.req) case .executeRequests(let requests, let connection): - requests.forEach { connection.executeRequest($0.req) } + for request in requests { + connection.executeRequest(request.req) + } case .failRequest(let request, let error): request.req.fail(error) case .failRequests(let requests, let error): - requests.forEach { $0.req.fail(error) } + for request in requests { request.req.fail(error) } case .none: break @@ -328,9 +378,12 @@ final class HTTPConnectionPool { } private func createConnection(_ connectionID: Connection.ID, on eventLoop: EventLoop) { - self.logger.trace("Opening fresh connection", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Opening fresh connection", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) // Even though this function is called make it actually creates/establishes a connection. // TBD: Should we rename it? To what? self.connectionFactory.makeConnection( @@ -373,16 +426,19 @@ final class HTTPConnectionPool { } private func scheduleIdleTimerForConnection(_ connectionID: Connection.ID, on eventLoop: EventLoop) { - self.logger.trace("Schedule idle connection timeout timer", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Schedule idle connection timeout timer", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) let scheduled = eventLoop.scheduleTask(in: self.idleConnectionTimeout) { // there might be a race between a cancelTimer call and the triggering // of this scheduled task. both want to acquire the lock self.modifyStateAndRunActions { stateMachine in if self._idleTimer.removeValue(forKey: connectionID) != nil { // The timer still exists. State Machines assumes it is alive - return stateMachine.connectionIdleTimeout(connectionID) + return stateMachine.connectionIdleTimeout(connectionID, on: eventLoop) } return .none } @@ -393,9 +449,12 @@ final class HTTPConnectionPool { } private func cancelIdleTimerForConnection(_ connectionID: Connection.ID) { - self.logger.trace("Cancel idle connection timeout timer", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Cancel idle connection timeout timer", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) guard let cancelTimer = self._idleTimer.removeValue(forKey: connectionID) else { preconditionFailure("Expected to have an idle timer for connection \(connectionID) at this point.") } @@ -407,9 +466,12 @@ final class HTTPConnectionPool { _ timeAmount: TimeAmount, on eventLoop: EventLoop ) { - self.logger.trace("Schedule connection creation backoff timer", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Schedule connection creation backoff timer", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) let scheduled = eventLoop.scheduleTask(in: timeAmount) { // there might be a race between a backoffTimer and the pool shutting down. @@ -437,99 +499,139 @@ final class HTTPConnectionPool { // MARK: - Protocol methods - extension HTTPConnectionPool: HTTPConnectionRequester { - func http1ConnectionCreated(_ connection: HTTP1Connection) { - self.logger.trace("successfully created connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/1.1", - ]) + func http1ConnectionCreated(_ connection: HTTP1Connection.SendableView) { + self.logger.trace( + "successfully created connection", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/1.1", + ] + ) self.modifyStateAndRunActions { $0.newHTTP1ConnectionCreated(.http1_1(connection)) } } - func http2ConnectionCreated(_ connection: HTTP2Connection, maximumStreams: Int) { - self.logger.trace("successfully created connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - "ahc-max-streams": "\(maximumStreams)", - ]) + func http2ConnectionCreated(_ connection: HTTP2Connection.SendableView, maximumStreams: Int) { + self.logger.trace( + "successfully created connection", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/2", + "ahc-max-streams": "\(maximumStreams)", + ] + ) self.modifyStateAndRunActions { $0.newHTTP2ConnectionCreated(.http2(connection), maxConcurrentStreams: maximumStreams) } } func failedToCreateHTTPConnection(_ connectionID: HTTPConnectionPool.Connection.ID, error: Error) { - self.logger.debug("connection attempt failed", metadata: [ - "ahc-error": "\(error)", - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.debug( + "connection attempt failed", + metadata: [ + "ahc-error": "\(error)", + "ahc-connection-id": "\(connectionID)", + ] + ) self.modifyStateAndRunActions { $0.failedToCreateNewConnection(error, connectionID: connectionID) } } + + func waitingForConnectivity(_ connectionID: HTTPConnectionPool.Connection.ID, error: Error) { + self.logger.debug( + "waiting for connectivity", + metadata: [ + "ahc-error": "\(error)", + "ahc-connection-id": "\(connectionID)", + ] + ) + self.modifyStateAndRunActions { + $0.waitingForConnectivity(error, connectionID: connectionID) + } + } } extension HTTPConnectionPool: HTTP1ConnectionDelegate { - func http1ConnectionClosed(_ connection: HTTP1Connection) { - self.logger.debug("connection closed", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/1.1", - ]) + func http1ConnectionClosed(_ id: HTTPConnectionPool.Connection.ID) { + self.logger.debug( + "connection closed", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/1.1", + ] + ) self.modifyStateAndRunActions { - $0.http1ConnectionClosed(connection.id) + $0.http1ConnectionClosed(id) } } - func http1ConnectionReleased(_ connection: HTTP1Connection) { - self.logger.trace("releasing connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/1.1", - ]) + func http1ConnectionReleased(_ id: HTTPConnectionPool.Connection.ID) { + self.logger.trace( + "releasing connection", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/1.1", + ] + ) self.modifyStateAndRunActions { - $0.http1ConnectionReleased(connection.id) + $0.http1ConnectionReleased(id) } } } extension HTTPConnectionPool: HTTP2ConnectionDelegate { - func http2Connection(_ connection: HTTP2Connection, newMaxStreamSetting: Int) { - self.logger.debug("new max stream setting", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - "ahc-max-streams": "\(newMaxStreamSetting)", - ]) + func http2Connection(_ id: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) { + self.logger.debug( + "new max stream setting", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/2", + "ahc-max-streams": "\(newMaxStreamSetting)", + ] + ) self.modifyStateAndRunActions { - $0.newHTTP2MaxConcurrentStreamsReceived(connection.id, newMaxStreams: newMaxStreamSetting) + $0.newHTTP2MaxConcurrentStreamsReceived(id, newMaxStreams: newMaxStreamSetting) } } - func http2ConnectionGoAwayReceived(_ connection: HTTP2Connection) { - self.logger.debug("connection go away received", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - ]) + func http2ConnectionGoAwayReceived(_ id: HTTPConnectionPool.Connection.ID) { + self.logger.debug( + "connection go away received", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/2", + ] + ) self.modifyStateAndRunActions { - $0.http2ConnectionGoAwayReceived(connection.id) + $0.http2ConnectionGoAwayReceived(id) } } - func http2ConnectionClosed(_ connection: HTTP2Connection) { - self.logger.debug("connection closed", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - ]) + func http2ConnectionClosed(_ id: HTTPConnectionPool.Connection.ID) { + self.logger.debug( + "connection closed", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/2", + ] + ) self.modifyStateAndRunActions { - $0.http2ConnectionClosed(connection.id) + $0.http2ConnectionClosed(id) } } - func http2ConnectionStreamClosed(_ connection: HTTP2Connection, availableStreams: Int) { - self.logger.trace("stream closed", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - ]) + func http2ConnectionStreamClosed(_ id: HTTPConnectionPool.Connection.ID, availableStreams: Int) { + self.logger.trace( + "stream closed", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/2", + ] + ) self.modifyStateAndRunActions { - $0.http2ConnectionStreamClosed(connection.id) + $0.http2ConnectionStreamClosed(id) } } } @@ -548,18 +650,18 @@ extension HTTPConnectionPool { typealias ID = Int private enum Reference { - case http1_1(HTTP1Connection) - case http2(HTTP2Connection) + case http1_1(HTTP1Connection.SendableView) + case http2(HTTP2Connection.SendableView) case __testOnly_connection(ID, EventLoop) } private let _ref: Reference - fileprivate static func http1_1(_ conn: HTTP1Connection) -> Self { + fileprivate static func http1_1(_ conn: HTTP1Connection.SendableView) -> Self { Connection(_ref: .http1_1(conn)) } - fileprivate static func http2(_ conn: HTTP2Connection) -> Self { + fileprivate static func http2(_ conn: HTTP2Connection.SendableView) -> Self { Connection(_ref: .http2(conn)) } @@ -631,7 +733,9 @@ extension HTTPConnectionPool { return lhsConn.id == rhsConn.id case (.http2(let lhsConn), .http2(let rhsConn)): return lhsConn.id == rhsConn.id - case (.__testOnly_connection(let lhsID, let lhsEventLoop), .__testOnly_connection(let rhsID, let rhsEventLoop)): + case ( + .__testOnly_connection(let lhsID, let lhsEventLoop), .__testOnly_connection(let rhsID, let rhsEventLoop) + ): return lhsID == rhsID && lhsEventLoop === rhsEventLoop default: return false @@ -712,7 +816,7 @@ struct EventLoopID: Hashable { } static func __testOnly_fakeID(_ id: Int) -> EventLoopID { - return EventLoopID(.__testOnly_fakeID(id)) + EventLoopID(.__testOnly_fakeID(id)) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift index 2477e1154..bce55eb5b 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift @@ -132,7 +132,7 @@ import NIOSSL /// /// Use this handle to cancel the request, while it is waiting for a free connection, to execute the request. /// This protocol is only intended to be implemented by the `HTTPConnectionPool`. -protocol HTTPRequestScheduler { +protocol HTTPRequestScheduler: Sendable { /// Informs the task queuer that a request has been cancelled. func cancelRequest(_: HTTPSchedulableRequest) } @@ -176,16 +176,16 @@ protocol HTTPSchedulableRequest: HTTPExecutableRequest { /// A handle to the request executor. /// /// This protocol is implemented by the `HTTP1ClientChannelHandler`. -protocol HTTPRequestExecutor { +protocol HTTPRequestExecutor: Sendable { /// Writes a body part into the channel pipeline /// /// This method may be **called on any thread**. The executor needs to ensure thread safety. - func writeRequestBodyPart(_: IOData, request: HTTPExecutableRequest) + func writeRequestBodyPart(_: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) /// Signals that the request body stream has finished /// /// This method may be **called on any thread**. The executor needs to ensure thread safety. - func finishRequestBodyStream(_ task: HTTPExecutableRequest) + func finishRequestBodyStream(_ task: HTTPExecutableRequest, promise: EventLoopPromise?) /// Signals that more bytes from response body stream can be consumed. /// @@ -201,7 +201,7 @@ protocol HTTPRequestExecutor { func cancelRequest(_ task: HTTPExecutableRequest) } -protocol HTTPExecutableRequest: AnyObject { +protocol HTTPExecutableRequest: AnyObject, Sendable { /// The request's logger var logger: Logger { get } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift index 90578bc87..5c5b893e0 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift @@ -104,8 +104,8 @@ extension HTTPRequestStateMachine { // forwarded to the user. case .waitingForRead, - .waitingForDemand, - .waitingForReadOrDemand: + .waitingForDemand, + .waitingForReadOrDemand: return nil case .modifying: @@ -174,8 +174,8 @@ extension HTTPRequestStateMachine { return (buffer, .none) case .waitingForReadOrDemand(let buffer), - .waitingForRead(let buffer), - .waitingForDemand(let buffer): + .waitingForRead(let buffer), + .waitingForDemand(let buffer): // Normally this code path should never be hit. However there is one way to trigger // this: // diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index fa520a865..e06389360 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -20,21 +20,24 @@ struct HTTPRequestStateMachine { fileprivate enum State { /// The initial state machine state. The only valid mutation is `start()`. The state will /// transitions to: - /// - `.waitForChannelToBecomeWritable` - /// - `.running(.streaming, .initialized)` (if the Channel is writable and if a request body is expected) - /// - `.running(.endSent, .initialized)` (if the Channel is writable and no request body is expected) + /// - `.waitForChannelToBecomeWritable` (if the channel becomes non writable while sending the header) + /// - `.sendingHead` if the channel is writable case initialized + /// Waiting for the channel to be writable. Valid transitions are: - /// - `.running(.streaming, .initialized)` (once the Channel is writable again and if a request body is expected) - /// - `.running(.endSent, .initialized)` (once the Channel is writable again and no request body is expected) + /// - `.running(.streaming, .waitingForHead)` (once the Channel is writable again and if a request body is expected) + /// - `.running(.endSent, .waitingForHead)` (once the Channel is writable again and no request body is expected) /// - `.failed` (if a connection error occurred) case waitForChannelToBecomeWritable(HTTPRequestHead, RequestFramingMetadata) + /// A request is on the wire. Valid transitions are: /// - `.finished` /// - `.failed` case running(RequestState, ResponseState) + /// The request has completed successfully case finished + /// The request has failed case failed(Error) @@ -55,7 +58,7 @@ struct HTTPRequestStateMachine { /// The request is streaming its request body. `expectedBodyLength` has a value, if the request header contained /// a `"content-length"` header field. If the request header contained a `"transfer-encoding" = "chunked"` /// header field, the `expectedBodyLength` is `nil`. - case streaming(expectedBodyLength: Int?, sentBodyBytes: Int, producer: ProducerControlState) + case streaming(expectedBodyLength: Int64?, sentBodyBytes: Int64, producer: ProducerControlState) /// The request has sent its request body and end. case endSent } @@ -70,21 +73,38 @@ struct HTTPRequestStateMachine { } enum Action { - /// A action to execute, when we consider a request "done". - enum FinalStreamAction { + /// A action to execute, when we consider a successful request "done". + enum FinalSuccessfulRequestAction { /// Close the connection case close /// If the server has replied, with a status of 200...300 before all data was sent, a request is considered succeeded, /// as soon as we wrote the request end onto the wire. - case sendRequestEnd + /// + /// The promise is an optional write promise. + case sendRequestEnd(EventLoopPromise?) + /// Do nothing. This is action is used, if the request failed, before we the request head was written onto the wire. + /// This might happen if the request is cancelled, or the request failed the soundness check. + case none + } + + /// A action to execute, when we consider a failed request "done". + enum FinalFailedRequestAction { + /// Close the connection + case close(EventLoopPromise?) /// Do nothing. This is action is used, if the request failed, before we the request head was written onto the wire. /// This might happen if the request is cancelled, or the request failed the soundness check. case none } - case sendRequestHead(HTTPRequestHead, startBody: Bool) - case sendBodyPart(IOData) - case sendRequestEnd + case sendRequestHead(HTTPRequestHead, sendEnd: Bool) + case notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: Bool, + startIdleTimer: Bool + ) + case sendBodyPart(IOData, EventLoopPromise?) + case sendRequestEnd(EventLoopPromise?) + case failSendBodyPart(Error, EventLoopPromise?) + case failSendStreamFinished(Error, EventLoopPromise?) case pauseRequestBodyStream case resumeRequestBodyStream @@ -92,8 +112,8 @@ struct HTTPRequestStateMachine { case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyParts(CircularBuffer) - case failRequest(Error, FinalStreamAction) - case succeedRequest(FinalStreamAction, CircularBuffer) + case failRequest(Error, FinalFailedRequestAction) + case succeedRequest(FinalSuccessfulRequestAction, CircularBuffer) case read case wait @@ -141,10 +161,10 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, - .running(.streaming(_, _, producer: .producing), _), - .running(.endSent, _), - .finished, - .failed: + .running(.streaming(_, _, producer: .producing), _), + .running(.endSent, _), + .finished, + .failed: return .wait case .waitForChannelToBecomeWritable(let head, let metadata): @@ -176,11 +196,11 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.streaming(_, _, producer: .paused), _), - .running(.endSent, _), - .finished, - .failed: + .waitForChannelToBecomeWritable, + .running(.streaming(_, _, producer: .paused), _), + .running(.endSent, _), + .finished, + .failed: return .wait case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .producing), let responseState): @@ -199,20 +219,24 @@ struct HTTPRequestStateMachine { mutating func errorHappened(_ error: Error) -> Action { if let error = error as? NIOSSLError, - error == .uncleanShutdown, - let action = self.handleNIOSSLUncleanShutdownError() { + error == .uncleanShutdown, + let action = self.handleNIOSSLUncleanShutdownError() + { return action } switch self.state { case .initialized: - preconditionFailure("After the state machine has been initialized, start must be called immediately. Thus this state is unreachable") + preconditionFailure( + "After the state machine has been initialized, start must be called immediately. Thus this state is unreachable" + ) case .waitForChannelToBecomeWritable: // the request failed, before it was sent onto the wire. self.state = .failed(error) return .failRequest(error, .none) + case .running: self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) case .finished, .failed: // ignore error @@ -226,14 +250,14 @@ struct HTTPRequestStateMachine { private mutating func handleNIOSSLUncleanShutdownError() -> Action? { switch self.state { case .running(.streaming, .waitingForHead), - .running(.endSent, .waitingForHead): + .running(.endSent, .waitingForHead): // if we received a NIOSSL.uncleanShutdown before we got an answer we should handle // this like a normal connection close. We will receive a call to channelInactive after // this error. return .wait case .running(.streaming, .receivingBody(let responseHead, _)), - .running(.endSent, .receivingBody(let responseHead, _)): + .running(.endSent, .receivingBody(let responseHead, _)): // This code is only reachable for request and responses, which we expect to have a body. // We depend on logic from the HTTPResponseDecoder here. The decoder will emit an // HTTPResponsePart.end right after the HTTPResponsePart.head, for every request with a @@ -242,7 +266,9 @@ struct HTTPRequestStateMachine { // For this reason we only need to check the "content-length" or "transfer-encoding" // headers here to determine if we are potentially in an EOF terminated response. - if responseHead.headers.contains(name: "content-length") || responseHead.headers.contains(name: "transfer-encoding") { + if responseHead.headers.contains(name: "content-length") + || responseHead.headers.contains(name: "transfer-encoding") + { // If we have already received the response head, the parser will ensure that we // receive a complete response, if the content-length or transfer-encoding header // was set. In this case we can ignore the NIOSSLError.uncleanShutdown. We will see @@ -254,19 +280,21 @@ struct HTTPRequestStateMachine { // we have received all necessary bytes. For this reason we forward the uncleanShutdown // error to the user. self.state = .failed(NIOSSLError.uncleanShutdown) - return .failRequest(NIOSSLError.uncleanShutdown, .close) + return .failRequest(NIOSSLError.uncleanShutdown, .close(nil)) case .waitForChannelToBecomeWritable, .running, .finished, .failed, .initialized, .modifying: return nil } } - mutating func requestStreamPartReceived(_ part: IOData) -> Action { + mutating func requestStreamPartReceived(_ part: IOData, promise: EventLoopPromise?) -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.endSent, _): - preconditionFailure("We must be in the request streaming phase, if we receive further body parts. Invalid state: \(self.state)") + .waitForChannelToBecomeWritable, + .running(.endSent, _): + preconditionFailure( + "We must be in the request streaming phase, if we receive further body parts. Invalid state: \(self.state)" + ) case .running(.streaming(_, _, let producerState), .receivingBody(let head, _)) where head.status.code >= 300: // If we have already received a response head with status >= 300, we won't send out any @@ -274,7 +302,7 @@ struct HTTPRequestStateMachine { // won't be interested. We expect that the producer has been informed to pause // producing. assert(producerState == .paused) - return .wait + return .failSendBodyPart(HTTPClientError.requestStreamCancelled, promise) case .running(.streaming(let expectedBodyLength, var sentBodyBytes, let producerState), let responseState): // We don't check the producer state here: @@ -287,13 +315,13 @@ struct HTTPRequestStateMachine { // pause. The reason for this is as follows: There might be thread synchronization // situations in which the producer might not have received the plea to pause yet. - if let expected = expectedBodyLength, sentBodyBytes + part.readableBytes > expected { + if let expected = expectedBodyLength, sentBodyBytes + Int64(part.readableBytes) > expected { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } - sentBodyBytes += part.readableBytes + sentBodyBytes += Int64(part.readableBytes) let requestState: RequestState = .streaming( expectedBodyLength: expectedBodyLength, @@ -303,10 +331,10 @@ struct HTTPRequestStateMachine { self.state = .running(requestState, responseState) - return .sendBodyPart(part) + return .sendBodyPart(part, promise) - case .failed: - return .wait + case .failed(let error): + return .failSendBodyPart(error, promise) case .finished: // A request may be finished, before we have send all parts. This might be the case if @@ -318,54 +346,59 @@ struct HTTPRequestStateMachine { // We may still receive something, here because of potential race conditions with the // producing thread. - return .wait + return .failSendBodyPart(HTTPClientError.requestStreamCancelled, promise) case .modifying: preconditionFailure("Invalid state: \(self.state)") } } - mutating func requestStreamFinished() -> Action { + mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.endSent, _): - preconditionFailure("A request body stream end is only expected if we are in state request streaming. Invalid state: \(self.state)") + .waitForChannelToBecomeWritable, + .running(.endSent, _): + preconditionFailure( + "A request body stream end is only expected if we are in state request streaming. Invalid state: \(self.state)" + ) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .waitingForHead): if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } self.state = .running(.endSent, .waitingForHead) - return .sendRequestEnd + return .sendRequestEnd(promise) - case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .receivingBody(let head, let streamState)): + case .running( + .streaming(let expectedBodyLength, let sentBodyBytes, _), + .receivingBody(let head, let streamState) + ): assert(head.status.code < 300) if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } self.state = .running(.endSent, .receivingBody(head, streamState)) - return .sendRequestEnd + return .sendRequestEnd(promise) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .endReceived): if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } self.state = .finished - return .succeedRequest(.sendRequestEnd, .init()) + return .succeedRequest(.sendRequestEnd(promise), .init()) - case .failed: - return .wait + case .failed(let error): + return .failSendStreamFinished(error, promise) case .finished: // A request may be finished, before we have send all parts. This might be the case if @@ -377,7 +410,7 @@ struct HTTPRequestStateMachine { // We may still receive something, here because of potential race conditions with the // producing thread. - return .wait + return .failSendStreamFinished(HTTPClientError.requestStreamCancelled, promise) case .modifying: preconditionFailure("Invalid state: \(self.state)") @@ -398,7 +431,7 @@ struct HTTPRequestStateMachine { case .running: let error = HTTPClientError.cancelled self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) case .finished: return .wait @@ -435,11 +468,11 @@ struct HTTPRequestStateMachine { mutating func read() -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(_, .waitingForHead), - .running(_, .endReceived), - .finished, - .failed: + .waitForChannelToBecomeWritable, + .running(_, .waitingForHead), + .running(_, .endReceived), + .finished, + .failed: // If we are not in the middle of streaming the response body, we always want to get // more data... return .read @@ -472,11 +505,11 @@ struct HTTPRequestStateMachine { mutating func channelReadComplete() -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(_, .waitingForHead), - .running(_, .endReceived), - .finished, - .failed: + .waitForChannelToBecomeWritable, + .running(_, .waitingForHead), + .running(_, .endReceived), + .finished, + .failed: return .wait case .running(let requestState, .receivingBody(let head, var streamState)): @@ -507,7 +540,9 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before sending a request head ourselves") + preconditionFailure( + "How can we receive a response head before sending a request head ourselves \(self.state)" + ) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), .waitingForHead): self.state = .running( @@ -525,7 +560,11 @@ struct HTTPRequestStateMachine { return .forwardResponseHead(head, pauseRequestBodyStream: true) } else { self.state = .running( - .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: .producing), + .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: .producing + ), .receivingBody(head, .init()) ) return .forwardResponseHead(head, pauseRequestBodyStream: false) @@ -536,7 +575,9 @@ struct HTTPRequestStateMachine { return .forwardResponseHead(head, pauseRequestBodyStream: false) case .running(_, .receivingBody), .running(_, .endReceived), .finished: - preconditionFailure("How can we successfully finish the request, before having received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we successfully finish the request, before having received a head. Invalid state: \(self.state)" + ) case .failed: return .wait @@ -548,10 +589,14 @@ struct HTTPRequestStateMachine { mutating func receivedHTTPResponseBodyPart(_ body: ByteBuffer) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response head before completely sending a request head ourselves. Invalid state: \(self.state)" + ) case .running(_, .waitingForHead): - preconditionFailure("How can we receive a response body, if we haven't received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response body, if we haven't received a head. Invalid state: \(self.state)" + ) case .running(let requestState, .receivingBody(let head, var responseStreamState)): return self.avoidingStateMachineCoW { state -> Action in @@ -561,7 +606,9 @@ struct HTTPRequestStateMachine { } case .running(_, .endReceived), .finished: - preconditionFailure("How can we successfully finish the request, before having received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we successfully finish the request, before having received a head. Invalid state: \(self.state)" + ) case .failed: return .wait @@ -574,20 +621,31 @@ struct HTTPRequestStateMachine { private mutating func receivedHTTPResponseEnd() -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before sending a request head ourselves. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response end before completely sending a request head ourselves. Invalid state: \(self.state)" + ) case .running(_, .waitingForHead): - preconditionFailure("How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)" + ) - case .running(.streaming(let expectedBodyLength, let sentBodyBytes, let producerState), .receivingBody(let head, var responseStreamState)) - where head.status.code < 300: + case .running( + .streaming(let expectedBodyLength, let sentBodyBytes, let producerState), + .receivingBody(let head, var responseStreamState) + ) + where head.status.code < 300: return self.avoidingStateMachineCoW { state -> Action in let (remainingBuffer, connectionAction) = responseStreamState.end() switch connectionAction { case .none: state = .running( - .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: producerState), + .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: producerState + ), .endReceived ) return .forwardResponseBodyParts(remainingBuffer) @@ -597,13 +655,16 @@ struct HTTPRequestStateMachine { // the request is still uploading, we will not be able to finish the upload. For // this reason we can fail the request here. state = .failed(HTTPClientError.remoteConnectionClosed) - return .failRequest(HTTPClientError.remoteConnectionClosed, .close) + return .failRequest(HTTPClientError.remoteConnectionClosed, .close(nil)) } } case .running(.streaming(_, _, let producerState), .receivingBody(let head, var responseStreamState)): assert(head.status.code >= 300) - assert(producerState == .paused, "Expected to have paused the request body stream, when the head was received. Invalid state: \(self.state)") + assert( + producerState == .paused, + "Expected to have paused the request body stream, when the head was received. Invalid state: \(self.state)" + ) return self.avoidingStateMachineCoW { state -> Action in // We can ignore the connectionAction from the responseStreamState, since the @@ -626,7 +687,9 @@ struct HTTPRequestStateMachine { } case .running(_, .endReceived), .finished: - preconditionFailure("How can we receive a response end, if another one was already received. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response end, if another one was already received. Invalid state: \(self.state)" + ) case .failed: return .wait @@ -639,9 +702,11 @@ struct HTTPRequestStateMachine { mutating func demandMoreResponseBodyParts() -> Action { switch self.state { case .initialized, - .running(_, .waitingForHead), - .waitForChannelToBecomeWritable: - preconditionFailure("The response is expected to only ask for more data after the response head was forwarded") + .running(_, .waitingForHead), + .waitForChannelToBecomeWritable: + preconditionFailure( + "The response is expected to only ask for more data after the response head was forwarded \(self.state)" + ) case .running(let requestState, .receivingBody(let head, var responseStreamState)): return self.avoidingStateMachineCoW { state -> Action in @@ -651,8 +716,8 @@ struct HTTPRequestStateMachine { } case .running(_, .endReceived), - .finished, - .failed: + .finished, + .failed: return .wait case .modifying: @@ -663,14 +728,16 @@ struct HTTPRequestStateMachine { mutating func idleReadTimeoutTriggered() -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.streaming, _): - preconditionFailure("We only schedule idle read timeouts after we have sent the complete request. Invalid state: \(self.state)") + .waitForChannelToBecomeWritable, + .running(.streaming, _): + preconditionFailure( + "We only schedule idle read timeouts after we have sent the complete request. Invalid state: \(self.state)" + ) case .running(.endSent, .waitingForHead), .running(.endSent, .receivingBody): let error = HTTPClientError.readTimeout self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) case .running(.endSent, .endReceived): preconditionFailure("Invalid state. This state should be: .finished") @@ -683,19 +750,84 @@ struct HTTPRequestStateMachine { } } + mutating func idleWriteTimeoutTriggered() -> Action { + switch self.state { + case .initialized, + .waitForChannelToBecomeWritable: + preconditionFailure( + "We only schedule idle write timeouts while the request is being sent. Invalid state: \(self.state)" + ) + + case .running(.streaming, _): + let error = HTTPClientError.writeTimeout + self.state = .failed(error) + return .failRequest(error, .close(nil)) + + case .running(.endSent, _): + preconditionFailure("Invalid state. This state should be: .finished") + + case .finished, .failed: + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + private mutating func startSendingRequest(head: HTTPRequestHead, metadata: RequestFramingMetadata) -> Action { - switch metadata.body { - case .stream: - self.state = .running(.streaming(expectedBodyLength: nil, sentBodyBytes: 0, producer: .producing), .waitingForHead) - return .sendRequestHead(head, startBody: true) - case .fixedSize(0): + let length = metadata.body.expectedLength + if length == 0 { // no body self.state = .running(.endSent, .waitingForHead) - return .sendRequestHead(head, startBody: false) - case .fixedSize(let length): - // length is greater than zero and we therefore have a body to send - self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .producing), .waitingForHead) - return .sendRequestHead(head, startBody: true) + return .sendRequestHead(head, sendEnd: true) + } else { + self.state = .running( + .streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .paused), + .waitingForHead + ) + return .sendRequestHead(head, sendEnd: false) + } + } + + mutating func headSent() -> Action { + switch self.state { + case .initialized, .waitForChannelToBecomeWritable, .finished: + preconditionFailure("Not a valid transition after `.sendingHeader`: \(self.state)") + + case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), let responseState): + let startProducing = self.isChannelWritable && expectedBodyLength != sentBodyBytes + self.state = .running( + .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: startProducing ? .producing : .paused + ), + responseState + ) + return .notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: startProducing, + startIdleTimer: false + ) + case .running(.endSent, _): + return .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + case .running(.streaming(_, _, producer: .producing), _): + preconditionFailure( + "request body producing can not start before we have successfully send the header \(self.state)" + ) + case .failed: + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } +} + +extension RequestFramingMetadata.Body { + var expectedLength: Int64? { + switch self { + case .fixedSize(let length): return length + case .stream: return nil } } } @@ -754,7 +886,8 @@ extension HTTPRequestStateMachine: CustomStringConvertible { case .waitForChannelToBecomeWritable: return "HTTPRequestStateMachine(.waitForChannelToBecomeWritable, isWritable: \(self.isChannelWritable))" case .running(let requestState, let responseState): - return "HTTPRequestStateMachine(.running(request: \(requestState), response: \(responseState)), isWritable: \(self.isChannelWritable))" + return + "HTTPRequestStateMachine(.running(request: \(requestState), response: \(responseState)), isWritable: \(self.isChannelWritable))" case .finished: return "HTTPRequestStateMachine(.finished, isWritable: \(self.isChannelWritable))" case .failed(let error): @@ -768,7 +901,7 @@ extension HTTPRequestStateMachine: CustomStringConvertible { extension HTTPRequestStateMachine.RequestState: CustomStringConvertible { var description: String { switch self { - case .streaming(expectedBodyLength: let expected, let sent, producer: let producer): + case .streaming(expectedBodyLength: let expected, let sent, let producer): return ".streaming(sent: \(expected != nil ? String(expected!) : "-"), sent: \(sent), producer: \(producer)" case .endSent: return ".endSent" diff --git a/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift b/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift index 38d90e057..58ba694a7 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift @@ -12,11 +12,13 @@ // //===----------------------------------------------------------------------===// +import NIOCore + /// - Note: use `HTTPClientRequest.Body.Length` if you want to expose `RequestBodyLength` publicly @usableFromInline -internal enum RequestBodyLength: Hashable { +internal enum RequestBodyLength: Hashable, Sendable { /// size of the request body is not known before starting the request case unknown /// size of the request body is fixed and exactly `count` bytes - case known(_ count: Int) + case known(_ count: Int64) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift b/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift index 98080e364..033060a99 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift @@ -15,7 +15,7 @@ struct RequestFramingMetadata: Hashable { enum Body: Hashable { case stream - case fixedSize(Int) + case fixedSize(Int64) } var connectionClose: Bool diff --git a/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift b/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift index 2092498d8..903f962e5 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/RequestOptions.swift @@ -17,16 +17,28 @@ import NIOCore struct RequestOptions { /// The maximal `TimeAmount` that is allowed to pass between `channelRead`s from the Channel. var idleReadTimeout: TimeAmount? + /// The maximal `TimeAmount` that is allowed to pass between `write`s into the Channel. + var idleWriteTimeout: TimeAmount? + /// DNS overrides. + var dnsOverride: [String: String] - init(idleReadTimeout: TimeAmount?) { + init( + idleReadTimeout: TimeAmount?, + idleWriteTimeout: TimeAmount?, + dnsOverride: [String: String] + ) { self.idleReadTimeout = idleReadTimeout + self.idleWriteTimeout = idleWriteTimeout + self.dnsOverride = dnsOverride } } extension RequestOptions { static func fromClientConfiguration(_ configuration: HTTPClient.Configuration) -> Self { RequestOptions( - idleReadTimeout: configuration.timeout.read + idleReadTimeout: configuration.timeout.read, + idleWriteTimeout: configuration.timeout.write, + dnsOverride: configuration.dnsOverride ) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift index 4aec9f6fe..e2ef564a5 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift @@ -13,8 +13,13 @@ //===----------------------------------------------------------------------===// import NIOCore + #if canImport(Darwin) import func Darwin.pow +#elseif canImport(Musl) +import func Musl.pow +#elseif canImport(Android) +import func Android.pow #else import func Glibc.pow #endif @@ -56,7 +61,7 @@ extension HTTPConnectionPool { // Calculate a 3% jitter range let jitterRange = (backoff.nanoseconds / 100) * 3 // Pick a random element from the range +/- jitter range. - let jitter: TimeAmount = .nanoseconds((-jitterRange...jitterRange).randomElement()!) + let jitter: TimeAmount = .nanoseconds(Int64.random(in: -jitterRange...jitterRange)) let jitteredBackoff = backoff + jitter return jitteredBackoff } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift index cdbf02394..3cdf51869 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift @@ -19,15 +19,15 @@ extension HTTPConnectionPool { private struct HTTP1ConnectionState { enum State { /// the connection is creating a connection. Valid transitions are to: .backingOff, .idle, and .closed - case starting + case starting(maximumUses: Int?) /// the connection is waiting to retry the establishing a connection. Valid transitions are to: .closed. /// This means, the connection can be removed from the connections without cancelling external /// state. The connection state can then be replaced by a new one. case backingOff /// the connection is idle for a new request. Valid transitions to: .leased and .closed - case idle(Connection, since: NIODeadline) + case idle(Connection, since: NIODeadline, remainingUses: Int?) /// the connection is leased and running for a request. Valid transitions to: .idle and .closed - case leased(Connection) + case leased(Connection, remainingUses: Int?) /// the connection is closed. final state. case closed } @@ -36,10 +36,10 @@ extension HTTPConnectionPool { let connectionID: Connection.ID let eventLoop: EventLoop - init(connectionID: Connection.ID, eventLoop: EventLoop) { + init(connectionID: Connection.ID, eventLoop: EventLoop, maximumUses: Int?) { self.connectionID = connectionID self.eventLoop = eventLoop - self.state = .starting + self.state = .starting(maximumUses: maximumUses) } var isConnecting: Bool { @@ -69,6 +69,19 @@ extension HTTPConnectionPool { } } + var idleAndNoRemainingUses: Bool { + switch self.state { + case .idle(_, since: _, let remainingUses): + if let remainingUses = remainingUses { + return remainingUses <= 0 + } else { + return false + } + case .backingOff, .starting, .leased, .closed: + return false + } + } + var canOrWillBeAbleToExecuteRequests: Bool { switch self.state { case .leased, .backingOff, .idle, .starting: @@ -89,7 +102,7 @@ extension HTTPConnectionPool { var idleSince: NIODeadline? { switch self.state { - case .idle(_, since: let idleSince): + case .idle(_, since: let idleSince, _): return idleSince case .backingOff, .starting, .leased, .closed: return nil @@ -107,8 +120,8 @@ extension HTTPConnectionPool { mutating func connected(_ connection: Connection) { switch self.state { - case .starting: - self.state = .idle(connection, since: .now()) + case .starting(maximumUses: let maxUses): + self.state = .idle(connection, since: .now(), remainingUses: maxUses) case .backingOff, .idle, .leased, .closed: preconditionFailure("Invalid state: \(self.state)") } @@ -126,8 +139,8 @@ extension HTTPConnectionPool { mutating func lease() -> Connection { switch self.state { - case .idle(let connection, since: _): - self.state = .leased(connection) + case .idle(let connection, since: _, let remainingUses): + self.state = .leased(connection, remainingUses: remainingUses.map { $0 - 1 }) return connection case .backingOff, .starting, .leased, .closed: preconditionFailure("Invalid state: \(self.state)") @@ -136,8 +149,8 @@ extension HTTPConnectionPool { mutating func release() { switch self.state { - case .leased(let connection): - self.state = .idle(connection, since: .now()) + case .leased(let connection, let remainingUses): + self.state = .idle(connection, since: .now(), remainingUses: remainingUses) case .backingOff, .starting, .idle, .closed: preconditionFailure("Invalid state: \(self.state)") } @@ -145,7 +158,7 @@ extension HTTPConnectionPool { mutating func close() -> Connection { switch self.state { - case .idle(let connection, since: _): + case .idle(let connection, since: _, remainingUses: _): self.state = .closed return connection case .backingOff, .starting, .leased, .closed: @@ -153,10 +166,22 @@ extension HTTPConnectionPool { } } - mutating func fail() { + enum FailAction { + case removeConnection + case none + } + + mutating func fail() -> FailAction { switch self.state { - case .starting, .backingOff, .idle, .leased: + case .starting: + // If the connection fails while we are starting it, the fail call raced with + // `failedToConnect` (promises are succeeded or failed before channel handler + // callbacks). let's keep the state in `starting`, so that `failedToConnect` can + // create a backoff timer. + return .none + case .backingOff, .idle, .leased: self.state = .closed + return .removeConnection case .closed: preconditionFailure("Invalid state: \(self.state)") } @@ -188,14 +213,16 @@ extension HTTPConnectionPool { return .removeConnection case .starting: return .keepConnection - case .idle(let connection, since: _): + case .idle(let connection, since: _, remainingUses: _): context.close.append(connection) return .removeConnection - case .leased(let connection): + case .leased(let connection, remainingUses: _): context.cancel.append(connection) return .keepConnection case .closed: - preconditionFailure("Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)") + preconditionFailure( + "Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)" + ) } } @@ -212,14 +239,16 @@ extension HTTPConnectionPool { case .backingOff: context.backingOff.append((self.connectionID, self.eventLoop)) return .removeConnection - case .idle(let connection, since: _): + case .idle(let connection, since: _, remainingUses: _): // Idle connections can be removed right away context.close.append(connection) return .removeConnection case .leased: return .keepConnection case .closed: - preconditionFailure("Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)") + preconditionFailure( + "Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)" + ) } } } @@ -243,31 +272,32 @@ extension HTTPConnectionPool { /// The index after which you will find the connections for requests with `EventLoop` /// requirements in `connections`. private var overflowIndex: Array.Index - - init(maximumConcurrentConnections: Int, generator: Connection.ID.Generator) { + /// The number of times each connection can be used before it is closed and replaced. + private let maximumConnectionUses: Int? + /// How many pre-warmed connections we should create. + private let preWarmedConnectionCount: Int + + init( + maximumConcurrentConnections: Int, + generator: Connection.ID.Generator, + maximumConnectionUses: Int?, + preWarmedHTTP1ConnectionCount: Int + ) { self.connections = [] - self.connections.reserveCapacity(maximumConcurrentConnections) + self.connections.reserveCapacity(min(maximumConcurrentConnections, 1024)) self.overflowIndex = self.connections.endIndex self.maximumConcurrentConnections = maximumConcurrentConnections self.generator = generator + self.maximumConnectionUses = maximumConnectionUses + self.preWarmedConnectionCount = preWarmedHTTP1ConnectionCount } var stats: Stats { - var stats = Stats() - // all additions here can be unchecked, since we will have at max self.connections.count - // which itself is an Int. For this reason we will never overflow. - for connectionState in self.connections { - if connectionState.isConnecting { - stats.connecting &+= 1 - } else if connectionState.isBackingOff { - stats.backingOff &+= 1 - } else if connectionState.isLeased { - stats.leased &+= 1 - } else if connectionState.isIdle { - stats.idle &+= 1 - } - } - return stats + self.connectionStats(in: self.connections.startIndex.. Int { - return self.connections[self.overflowIndex..) -> Stats { + var stats = Stats() + // all additions here can be unchecked, since we will have at max self.connections.count + // which itself is an Int. For this reason we will never overflow. + for connectionState in self.connections[range] { + if connectionState.isConnecting { + stats.connecting &+= 1 + } else if connectionState.isBackingOff { + stats.backingOff &+= 1 + } else if connectionState.isLeased { + stats.leased &+= 1 + } else if connectionState.isIdle { + stats.idle &+= 1 + } + } + return stats + } + // MARK: - Mutations - /// A connection's use. Did it serve in the pool or was it specialized for an `EventLoop`? @@ -323,6 +371,8 @@ extension HTTPConnectionPool { /// The connection's use. Either general purpose or for requests with `EventLoop` /// requirements. var use: ConnectionUse + /// Whether the connection should be closed. + var shouldBeClosed: Bool } /// Information around the failed/closed connection. @@ -345,14 +395,22 @@ extension HTTPConnectionPool { mutating func createNewConnection(on eventLoop: EventLoop) -> Connection.ID { precondition(self.canGrow) - let connection = HTTP1ConnectionState(connectionID: self.generator.next(), eventLoop: eventLoop) + let connection = HTTP1ConnectionState( + connectionID: self.generator.next(), + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) self.connections.insert(connection, at: self.overflowIndex) self.overflowIndex = self.connections.index(after: self.overflowIndex) return connection.connectionID } mutating func createNewOverflowConnection(on eventLoop: EventLoop) -> Connection.ID { - let connection = HTTP1ConnectionState(connectionID: self.generator.next(), eventLoop: eventLoop) + let connection = HTTP1ConnectionState( + connectionID: self.generator.next(), + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) self.connections.append(connection) return connection.connectionID } @@ -369,7 +427,10 @@ extension HTTPConnectionPool { guard let index = self.connections.firstIndex(where: { $0.connectionID == connection.id }) else { preconditionFailure("There is a new connection that we didn't request!") } - precondition(connection.eventLoop === self.connections[index].eventLoop, "Expected the new connection to be on EL") + precondition( + connection.eventLoop === self.connections[index].eventLoop, + "Expected the new connection to be on EL" + ) self.connections[index].connected(connection) let context = self.generateIdleConnectionContextForConnection(at: index) return (index, context) @@ -484,7 +545,8 @@ extension HTTPConnectionPool { precondition(self.connections[index].isClosed) let newConnection = HTTP1ConnectionState( connectionID: self.generator.next(), - eventLoop: self.connections[index].eventLoop + eventLoop: self.connections[index].eventLoop, + maximumUses: self.maximumConnectionUses ) self.connections[index] = newConnection @@ -509,23 +571,28 @@ extension HTTPConnectionPool { } let use: ConnectionUse - self.connections[index].fail() - let eventLoop = self.connections[index].eventLoop - let starting: Int - if index < self.overflowIndex { - use = .generalPurpose - starting = self.startingGeneralPurposeConnections - } else { - use = .eventLoop(eventLoop) - starting = self.startingEventLoopConnections(on: eventLoop) - } + switch self.connections[index].fail() { + case .removeConnection: + let eventLoop = self.connections[index].eventLoop + let starting: Int + if index < self.overflowIndex { + use = .generalPurpose + starting = self.startingGeneralPurposeConnections + } else { + use = .eventLoop(eventLoop) + starting = self.startingEventLoopConnections(on: eventLoop) + } - let context = FailedConnectionContext( - eventLoop: eventLoop, - use: use, - connectionsStartingForUseCase: starting - ) - return (index, context) + let context = FailedConnectionContext( + eventLoop: eventLoop, + use: use, + connectionsStartingForUseCase: starting + ) + return (index, context) + + case .none: + return nil + } } // MARK: Migration @@ -562,7 +629,12 @@ extension HTTPConnectionPool { backingOff: [(Connection.ID, EventLoop)] ) { for (connectionID, eventLoop) in starting { - let newConnection = HTTP1ConnectionState(connectionID: connectionID, eventLoop: eventLoop) + let newConnection = HTTP1ConnectionState( + connectionID: connectionID, + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) + self.connections.insert(newConnection, at: self.overflowIndex) /// If we can grow, we mark the connection as a general purpose connection. /// Otherwise, it will be an overflow connection which is only used once for requests with a required event loop @@ -572,9 +644,14 @@ extension HTTPConnectionPool { } for (connectionID, eventLoop) in backingOff { - var backingOffConnection = HTTP1ConnectionState(connectionID: connectionID, eventLoop: eventLoop) + var backingOffConnection = HTTP1ConnectionState( + connectionID: connectionID, + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) // TODO: Maybe we want to add a static init for backing off connections to HTTP1ConnectionState backingOffConnection.failedToConnect() + self.connections.insert(backingOffConnection, at: self.overflowIndex) /// If we can grow, we mark the connection as a general purpose connection. /// Otherwise, it will be an overflow connection which is only used once for requests with a required event loop @@ -602,21 +679,23 @@ extension HTTPConnectionPool { ) -> [(Connection.ID, EventLoop)] { // create new connections for requests with a required event loop - // we may already start connections for those requests and do not want to start to many + // we may already start connections for those requests and do not want to start too many let startingRequiredEventLoopConnectionCount = Dictionary( self.connections[self.overflowIndex.. [(Connection.ID, EventLoop)] in // We need a connection for each queued request with a required event loop. // Therefore, we look how many request we have queued for a given `eventLoop` and // how many connections we are already starting on the given `eventLoop`. // If we have not enough, we will create additional connections to have at least // on connection per request. - let connectionsToStart = requestCount - startingRequiredEventLoopConnectionCount[eventLoop.id, default: 0] + let connectionsToStart = + requestCount - startingRequiredEventLoopConnectionCount[eventLoop.id, default: 0] return stride(from: 0, to: connectionsToStart, by: 1).lazy.map { _ in (self.createNewOverflowConnection(on: eventLoop), eventLoop) } @@ -631,7 +710,8 @@ extension HTTPConnectionPool { // event loop we will continue with the event loop with the second most queued requests // and so on and so forth. The `generalPurposeRequestCountGroupedByPreferredEventLoop` // array is already ordered so we can just iterate over it without sorting by request count. - let newGeneralPurposeConnections: [(Connection.ID, EventLoop)] = generalPurposeRequestCountGroupedByPreferredEventLoop + let newGeneralPurposeConnections: [(Connection.ID, EventLoop)] = + generalPurposeRequestCountGroupedByPreferredEventLoop // we do not want to allocated intermediate arrays. .lazy // we flatten the grouped list of event loops by lazily repeating the event loop @@ -690,7 +770,8 @@ extension HTTPConnectionPool { } else { use = .eventLoop(eventLoop) } - return IdleConnectionContext(eventLoop: eventLoop, use: use) + let hasNoRemainingUses = self.connections[index].idleAndNoRemainingUses + return IdleConnectionContext(eventLoop: eventLoop, use: use, shouldBeClosed: hasNoRemainingUses) } private func findIdleConnection(onPreferred preferredEL: EventLoop) -> Int? { @@ -788,6 +869,10 @@ extension HTTPConnectionPool { var leased: Int = 0 var connecting: Int = 0 var backingOff: Int = 0 + + var nonLeased: Int { + self.idle + self.connecting + self.backingOff + } } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift index 6b3f7352e..395064377 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift @@ -17,6 +17,7 @@ import NIOCore extension HTTPConnectionPool { struct HTTP1StateMachine { typealias Action = HTTPConnectionPool.StateMachine.Action + typealias RequestAction = HTTPConnectionPool.StateMachine.RequestAction typealias ConnectionMigrationAction = HTTPConnectionPool.StateMachine.ConnectionMigrationAction typealias EstablishedAction = HTTPConnectionPool.StateMachine.EstablishedAction typealias EstablishedConnectionAction = HTTPConnectionPool.StateMachine.EstablishedConnectionAction @@ -29,16 +30,27 @@ extension HTTPConnectionPool { private(set) var requests: RequestQueue private(set) var lifecycleState: StateMachine.LifecycleState + /// The property was introduced to fail fast during testing. + /// Otherwise this should always be true and not turned off. + private let retryConnectionEstablishment: Bool + private let preWarmedConnectionCount: Int init( idGenerator: Connection.ID.Generator, maximumConcurrentConnections: Int, + retryConnectionEstablishment: Bool, + maximumConnectionUses: Int?, + preWarmedHTTP1ConnectionCount: Int, lifecycleState: StateMachine.LifecycleState ) { self.connections = HTTP1Connections( maximumConcurrentConnections: maximumConcurrentConnections, - generator: idGenerator + generator: idGenerator, + maximumConnectionUses: maximumConnectionUses, + preWarmedHTTP1ConnectionCount: preWarmedHTTP1ConnectionCount ) + self.preWarmedConnectionCount = preWarmedHTTP1ConnectionCount + self.retryConnectionEstablishment = retryConnectionEstablishment self.requests = RequestQueue() self.lifecycleState = lifecycleState @@ -72,7 +84,10 @@ extension HTTPConnectionPool { requests: RequestQueue ) -> ConnectionMigrationAction { precondition(self.connections.isEmpty, "expected an empty state machine but connections are not empty") - precondition(self.http2Connections == nil, "expected an empty state machine but http2Connections are not nil") + precondition( + self.http2Connections == nil, + "expected an empty state machine but http2Connections are not nil" + ) precondition(self.requests.isEmpty, "expected an empty state machine but requests are not empty") self.requests = requests @@ -92,7 +107,8 @@ extension HTTPConnectionPool { let createConnections = self.connections.createConnectionsAfterMigrationIfNeeded( requiredEventLoopOfPendingRequests: requests.requestCountGroupedByRequiredEventLoop(), - generalPurposeRequestCountGroupedByPreferredEventLoop: requests.generalPurposeRequestCountGroupedByPreferredEventLoop() + generalPurposeRequestCountGroupedByPreferredEventLoop: + requests.generalPurposeRequestCountGroupedByPreferredEventLoop() ) if !http2Connections.isEmpty { @@ -133,9 +149,26 @@ extension HTTPConnectionPool { private mutating func executeRequestOnPreferredEventLoop(_ request: Request, eventLoop: EventLoop) -> Action { if let connection = self.connections.leaseConnection(onPreferred: eventLoop) { + // Cool, a connection is available. If using this would put us below our needed extra set, we + // should create another. + let stats = self.connections.generalPurposeStats + let needExtraConnection = + stats.nonLeased < (self.requests.count + self.preWarmedConnectionCount) && self.connections.canGrow + let action: StateMachine.ConnectionAction + + if needExtraConnection { + action = .createConnectionAndCancelTimeoutTimer( + createdID: self.connections.createNewConnection(on: eventLoop), + on: eventLoop, + cancelTimerID: connection.id + ) + } else { + action = .cancelTimeoutTimer(connection.id) + } + return .init( request: .executeRequest(request, connection, cancelTimeout: false), - connection: .cancelTimeoutTimer(connection.id) + connection: action ) } @@ -217,12 +250,26 @@ extension HTTPConnectionPool { self.failedConsecutiveConnectionAttempts += 1 self.lastConnectFailure = error + // We don't care how many waiting requests we have at this point, we will schedule a + // retry. More tasks, may appear until the backoff has completed. The final + // decision about the retry will be made in `connectionCreationBackoffDone(_:)` + let eventLoop = self.connections.backoffNextConnectionAttempt(connectionID) + switch self.lifecycleState { case .running: - // We don't care how many waiting requests we have at this point, we will schedule a - // retry. More tasks, may appear until the backoff has completed. The final - // decision about the retry will be made in `connectionCreationBackoffDone(_:)` - let eventLoop = self.connections.backoffNextConnectionAttempt(connectionID) + guard self.retryConnectionEstablishment else { + guard let (index, _) = self.connections.failConnection(connectionID) else { + preconditionFailure( + "A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost." + ) + } + self.connections.removeConnection(at: index) + + return .init( + request: self.failAllRequests(reason: error), + connection: .none + ) + } let backoff = calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) return .init( @@ -241,6 +288,12 @@ extension HTTPConnectionPool { } } + mutating func waitingForConnectivity(_ error: Error, connectionID: Connection.ID) -> Action { + self.lastConnectFailure = error + + return .init(request: .none, connection: .none) + } + mutating func connectionCreationBackoffDone(_ connectionID: Connection.ID) -> Action { switch self.lifecycleState { case .running: @@ -263,14 +316,30 @@ extension HTTPConnectionPool { } } - mutating func connectionIdleTimeout(_ connectionID: Connection.ID) -> Action { + mutating func connectionIdleTimeout(_ connectionID: Connection.ID, on eventLoop: any EventLoop) -> Action { + // Don't close idle connections if we need pre-warmed connections. Instead, re-arm the idle timer. + // We still want the idle timers to make sure we eventually fall below the pre-warmed limit. + if self.preWarmedConnectionCount > 0 { + let stats = self.connections.generalPurposeStats + if stats.idle <= self.preWarmedConnectionCount { + return .init( + request: .none, + connection: .scheduleTimeoutTimer(connectionID, on: eventLoop) + ) + } + } + + // Ok, we do actually want the connection count to go down. guard let connection = self.connections.closeConnectionIfIdle(connectionID) else { // because of a race this connection (connection close runs against trigger of timeout) // was already removed from the state machine. return .none } - precondition(self.lifecycleState == .running, "If we are shutting down, we must not have any idle connections") + precondition( + self.lifecycleState == .running, + "If we are shutting down, we must not have any idle connections" + ) return .init( request: .none, @@ -317,9 +386,11 @@ extension HTTPConnectionPool { mutating func cancelRequest(_ requestID: Request.ID) -> Action { // 1. check requests in queue - if self.requests.remove(requestID) != nil { + if let request = self.requests.remove(requestID) { + // Use the last connection error to let the user know why the request was never scheduled + let error = self.lastConnectFailure ?? HTTPClientError.cancelled return .init( - request: .cancelRequestTimeout(requestID), + request: .failRequest(request, error, cancelTimeout: true), connection: .none ) } @@ -372,11 +443,16 @@ extension HTTPConnectionPool { ) -> EstablishedAction { switch self.lifecycleState { case .running: - switch context.use { - case .generalPurpose: - return self.nextActionForIdleGeneralPurposeConnection(at: index, context: context) - case .eventLoop: - return self.nextActionForIdleEventLoopConnection(at: index, context: context) + // Close the connection if it's expired. + if context.shouldBeClosed { + return self.nextActionForToBeClosedIdleConnection(at: index, context: context) + } else { + switch context.use { + case .generalPurpose: + return self.nextActionForIdleGeneralPurposeConnection(at: index, context: context) + case .eventLoop: + return self.nextActionForIdleEventLoopConnection(at: index, context: context) + } } case .shuttingDown(let unclean): assert(self.requests.isEmpty) @@ -401,28 +477,63 @@ extension HTTPConnectionPool { at index: Int, context: HTTP1Connections.IdleConnectionContext ) -> EstablishedAction { + var requestAction = HTTPConnectionPool.StateMachine.RequestAction.none + var parkedConnectionDetails: (HTTPConnectionPool.Connection.ID, any EventLoop)? = nil + // 1. Check if there are waiting requests in the general purpose queue if let request = self.requests.popFirst(for: nil) { - return .init( - request: .executeRequest(request, self.connections.leaseConnection(at: index), cancelTimeout: true), - connection: .none + requestAction = .executeRequest( + request, + self.connections.leaseConnection(at: index), + cancelTimeout: true ) } // 2. Check if there are waiting requests in the matching eventLoop queue - if let request = self.requests.popFirst(for: context.eventLoop) { - return .init( - request: .executeRequest(request, self.connections.leaseConnection(at: index), cancelTimeout: true), - connection: .none + if case .none = requestAction, let request = self.requests.popFirst(for: context.eventLoop) { + requestAction = .executeRequest( + request, + self.connections.leaseConnection(at: index), + cancelTimeout: true ) } // 3. Create a timeout timer to ensure the connection is closed if it is idle for too - // long. - let (connectionID, eventLoop) = self.connections.parkConnection(at: index) + // long, assuming we don't already have a use for it. + if case .none = requestAction { + parkedConnectionDetails = self.connections.parkConnection(at: index) + } + + // 4. We may need to create another connection to make sure we have enough pre-warmed ones. + // We need to do that if we have fewer non-leased connections than we need pre-warmed ones _and_ the pool can grow. + // Note that in this case we don't need to account for the number of pending requests, as that is 0: step 1 + // confirmed that. + let connectionAction: EstablishedConnectionAction + + if self.connections.generalPurposeStats.nonLeased < self.preWarmedConnectionCount + && self.connections.canGrow + { + // Re-use the event loop of the connection that just got created. + if let parkedConnectionDetails { + let newConnectionID = self.connections.createNewConnection(on: parkedConnectionDetails.1) + connectionAction = .scheduleTimeoutTimerAndCreateConnection( + timeoutID: parkedConnectionDetails.0, + newConnectionID: newConnectionID, + on: parkedConnectionDetails.1 + ) + } else { + let newConnectionID = self.connections.createNewConnection(on: context.eventLoop) + connectionAction = .createConnection(connectionID: newConnectionID, on: context.eventLoop) + } + } else if let parkedConnectionDetails { + connectionAction = .scheduleTimeoutTimer(parkedConnectionDetails.0, on: parkedConnectionDetails.1) + } else { + connectionAction = .none + } + return .init( - request: .none, - connection: .scheduleTimeoutTimer(connectionID, on: eventLoop) + request: requestAction, + connection: connectionAction ) } @@ -450,6 +561,37 @@ extension HTTPConnectionPool { ) } + private mutating func nextActionForToBeClosedIdleConnection( + at index: Int, + context: HTTP1Connections.IdleConnectionContext + ) -> EstablishedAction { + // Step 1: Tell the connection pool to drop what it knows about this object. + let connectionToClose = self.connections.closeConnection(at: index) + + // Step 2: Check whether we need a connection to replace this one. We do if we have fewer non-leased connections + // than we requests + minimumPrewarming count _and_ the pool can grow. Note that in many cases the above closure + // will have made some space, which is just fine. + let nonLeased = self.connections.generalPurposeStats.nonLeased + let neededNonLeased = self.requests.generalPurposeCount + self.preWarmedConnectionCount + + let connectionAction: EstablishedConnectionAction + if nonLeased < neededNonLeased && self.connections.canGrow { + // We re-use the EL of the connection we just closed. + let newConnectionID = self.connections.createNewConnection(on: connectionToClose.eventLoop) + connectionAction = .closeConnectionAndCreateConnection( + closeConnection: connectionToClose, + newConnectionID: newConnectionID, + on: connectionToClose.eventLoop + ) + } else { + connectionAction = .closeConnection(connectionToClose, isShutdown: .no) + } + return .init( + request: .none, + connection: connectionAction + ) + } + // MARK: Failed/Closed connection management private mutating func nextActionForFailedConnection( @@ -485,7 +627,10 @@ extension HTTPConnectionPool { at index: Int, context: HTTP1Connections.FailedConnectionContext ) -> Action { - if context.connectionsStartingForUseCase < self.requests.generalPurposeCount { + let needConnectionForRequest = + context.connectionsStartingForUseCase + < (self.requests.generalPurposeCount + self.preWarmedConnectionCount) + if needConnectionForRequest { // if we have more requests queued up, than we have starting connections, we should // create a new connection let (newConnectionID, newEventLoop) = self.connections.replaceConnection(at: index) @@ -515,9 +660,18 @@ extension HTTPConnectionPool { return .none } + private mutating func failAllRequests(reason error: Error) -> RequestAction { + let allRequests = self.requests.removeAll() + guard !allRequests.isEmpty else { + return .none + } + return .failRequestsAndCancelTimeouts(allRequests, error) + } + // MARK: HTTP2 - mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action { + mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action + { // The `http2Connections` are optional here: // Connections report events back to us, if they are in a shutdown that was // initiated by the state machine. For this reason this callback might be invoked @@ -619,6 +773,7 @@ extension HTTPConnectionPool.HTTP1StateMachine: CustomStringConvertible { let stats = self.connections.stats let queued = self.requests.count - return "connections: [connecting: \(stats.connecting) | backoff: \(stats.backingOff) | leased: \(stats.leased) | idle: \(stats.idle)], queued: \(queued)" + return + "connections: [connecting: \(stats.connecting) | backoff: \(stats.backingOff) | leased: \(stats.leased) | idle: \(stats.idle)], queued: \(queued)" } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift index 7aa504d03..2a0e0cc80 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift @@ -18,12 +18,12 @@ extension HTTPConnectionPool { private struct HTTP2ConnectionState { private enum State { /// the pool is establishing a connection. Valid transitions are to: .backingOff, .active and .closed - case starting + case starting(maximumUses: Int?) /// the connection is waiting to retry to establish a connection. Valid transitions are to .closed. /// From .closed a new connection state must be created for a retry. case backingOff /// the connection is active and is able to run requests. Valid transitions are to: .draining and .closed - case active(Connection, maxStreams: Int, usedStreams: Int, lastIdle: NIODeadline) + case active(Connection, maxStreams: Int, usedStreams: Int, lastIdle: NIODeadline, remainingUses: Int?) /// the connection is active and is running requests. No new requests must be scheduled. /// Valid transitions to: .draining and .closed case draining(Connection, maxStreams: Int, usedStreams: Int) @@ -71,8 +71,12 @@ extension HTTPConnectionPool { /// A request can be scheduled on the connection var isAvailable: Bool { switch self.state { - case .active(_, let maxStreams, let usedStreams, _): - return usedStreams < maxStreams + case .active(_, let maxStreams, let usedStreams, _, let remainingUses): + if let remainingUses = remainingUses { + return usedStreams < maxStreams && remainingUses > 0 + } else { + return usedStreams < maxStreams + } case .starting, .backingOff, .draining, .closed: return false } @@ -82,7 +86,7 @@ extension HTTPConnectionPool { /// Every idle connection is available, but not every available connection is idle. var isIdle: Bool { switch self.state { - case .active(_, _, let usedStreams, _): + case .active(_, _, let usedStreams, _, _): return usedStreams == 0 case .starting, .backingOff, .draining, .closed: return false @@ -112,9 +116,19 @@ extension HTTPConnectionPool { case .active, .draining, .backingOff, .closed: preconditionFailure("Invalid state: \(self.state)") - case .starting: - self.state = .active(conn, maxStreams: maxStreams, usedStreams: 0, lastIdle: .now()) - return maxStreams + case .starting(let maxUses): + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: 0, + lastIdle: .now(), + remainingUses: maxUses + ) + if let maxUses = maxUses { + return min(maxStreams, maxUses) + } else { + return maxStreams + } } } @@ -127,9 +141,20 @@ extension HTTPConnectionPool { case .starting, .backingOff, .closed: preconditionFailure("Invalid state for updating max concurrent streams: \(self.state)") - case .active(let conn, _, let usedStreams, let lastIdle): - self.state = .active(conn, maxStreams: maxStreams, usedStreams: usedStreams, lastIdle: lastIdle) - return max(maxStreams - usedStreams, 0) + case .active(let conn, _, let usedStreams, let lastIdle, let remainingUses): + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: usedStreams, + lastIdle: lastIdle, + remainingUses: remainingUses + ) + let availableStreams = max(maxStreams - usedStreams, 0) + if let remainingUses = remainingUses { + return min(remainingUses, availableStreams) + } else { + return availableStreams + } case .draining(let conn, _, let usedStreams): self.state = .draining(conn, maxStreams: maxStreams, usedStreams: usedStreams) @@ -142,7 +167,7 @@ extension HTTPConnectionPool { case .starting, .backingOff, .closed: preconditionFailure("Invalid state for draining a connection: \(self.state)") - case .active(let conn, let maxStreams, let usedStreams, _): + case .active(let conn, let maxStreams, let usedStreams, _, _): self.state = .draining(conn, maxStreams: maxStreams, usedStreams: usedStreams) return conn.eventLoop @@ -162,10 +187,22 @@ extension HTTPConnectionPool { } } - mutating func fail() { + enum FailAction { + case removeConnection + case none + } + + mutating func fail() -> FailAction { switch self.state { - case .starting, .active, .backingOff, .draining: + case .starting: + // If the connection fails while we are starting it, the fail call raced with + // `failedToConnect` (promises are succeeded or failed before channel handler + // callbacks). let's keep the state in `starting`, so that `failedToConnect` can + // create a backoff timer. + return .none + case .active, .backingOff, .draining: self.state = .closed + return .removeConnection case .closed: preconditionFailure("Invalid state: \(self.state)") } @@ -176,10 +213,20 @@ extension HTTPConnectionPool { case .starting, .backingOff, .draining, .closed: preconditionFailure("Invalid state for leasing a stream: \(self.state)") - case .active(let conn, let maxStreams, var usedStreams, let lastIdle): + case .active(let conn, let maxStreams, var usedStreams, let lastIdle, let remainingUses): usedStreams += count precondition(usedStreams <= maxStreams, "tried to lease a connection which is not available") - self.state = .active(conn, maxStreams: maxStreams, usedStreams: usedStreams, lastIdle: lastIdle) + precondition( + remainingUses.map { $0 >= count } ?? true, + "tried to lease streams from a connection which does not have enough remaining streams" + ) + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: usedStreams, + lastIdle: lastIdle, + remainingUses: remainingUses.map { $0 - count } + ) return conn } } @@ -191,14 +238,26 @@ extension HTTPConnectionPool { case .starting, .backingOff, .closed: preconditionFailure("Invalid state: \(self.state)") - case .active(let conn, let maxStreams, var usedStreams, var lastIdle): + case .active(let conn, let maxStreams, var usedStreams, var lastIdle, let remainingUses): precondition(usedStreams > 0, "we cannot release more streams than we have leased") usedStreams &-= 1 if usedStreams == 0 { lastIdle = .now() } - self.state = .active(conn, maxStreams: maxStreams, usedStreams: usedStreams, lastIdle: lastIdle) - return max(maxStreams &- usedStreams, 0) + + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: usedStreams, + lastIdle: lastIdle, + remainingUses: remainingUses + ) + let availableStreams = max(maxStreams &- usedStreams, 0) + if let remainingUses = remainingUses { + return min(availableStreams, remainingUses) + } else { + return availableStreams + } case .draining(let conn, let maxStreams, var usedStreams): precondition(usedStreams > 0, "we cannot release more streams than we have leased") @@ -210,7 +269,7 @@ extension HTTPConnectionPool { mutating func close() -> Connection { switch self.state { - case .active(let conn, _, 0, _): + case .active(let conn, _, 0, _, _): self.state = .closed return conn @@ -247,7 +306,7 @@ extension HTTPConnectionPool { context.connectBackoff.append(self.connectionID) return .removeConnection - case .active(let connection, _, let usedStreams, _): + case .active(let connection, _, let usedStreams, _, _): precondition(usedStreams >= 0) if usedStreams == 0 { context.close.append(connection) @@ -262,7 +321,9 @@ extension HTTPConnectionPool { return .keepConnection case .closed: - preconditionFailure("Unexpected state for cleanup: Did not expect to have closed connections in the state machine.") + preconditionFailure( + "Unexpected state for cleanup: Did not expect to have closed connections in the state machine." + ) } } @@ -274,7 +335,7 @@ extension HTTPConnectionPool { case .backingOff: stats.backingOffConnections &+= 1 - case .active(_, let maxStreams, let usedStreams, _): + case .active(_, let maxStreams, let usedStreams, _, _): stats.availableStreams += max(maxStreams - usedStreams, 0) stats.leasedStreams += usedStreams stats.availableConnections &+= 1 @@ -304,7 +365,7 @@ extension HTTPConnectionPool { context.starting.append((self.connectionID, self.eventLoop)) return .removeConnection - case .active(let connection, _, let usedStreams, _): + case .active(let connection, _, let usedStreams, _, _): precondition(usedStreams >= 0) if usedStreams == 0 { context.close.append(connection) @@ -321,14 +382,16 @@ extension HTTPConnectionPool { return .removeConnection case .closed: - preconditionFailure("Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)") + preconditionFailure( + "Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)" + ) } } - init(connectionID: Connection.ID, eventLoop: EventLoop) { + init(connectionID: Connection.ID, eventLoop: EventLoop, maximumUses: Int?) { self.connectionID = connectionID self.eventLoop = eventLoop - self.state = .starting + self.state = .starting(maximumUses: maximumUses) } } @@ -337,6 +400,8 @@ extension HTTPConnectionPool { private let generator: Connection.ID.Generator /// The connections states private var connections: [HTTP2ConnectionState] + /// The number of times each connection can be used before it is closed and replaced. + private let maximumConnectionUses: Int? var isEmpty: Bool { self.connections.isEmpty @@ -348,9 +413,10 @@ extension HTTPConnectionPool { } } - init(generator: Connection.ID.Generator) { + init(generator: Connection.ID.Generator, maximumConnectionUses: Int?) { self.generator = generator self.connections = [] + self.maximumConnectionUses = maximumConnectionUses } // MARK: Migration @@ -365,12 +431,20 @@ extension HTTPConnectionPool { backingOff: [(Connection.ID, EventLoop)] ) { for (connectionID, eventLoop) in starting { - let newConnection = HTTP2ConnectionState(connectionID: connectionID, eventLoop: eventLoop) + let newConnection = HTTP2ConnectionState( + connectionID: connectionID, + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) self.connections.append(newConnection) } for (connectionID, eventLoop) in backingOff { - var backingOffConnection = HTTP2ConnectionState(connectionID: connectionID, eventLoop: eventLoop) + var backingOffConnection = HTTP2ConnectionState( + connectionID: connectionID, + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) // TODO: Maybe we want to add a static init for backing off connections to HTTP2ConnectionState backingOffConnection.failedToConnect() self.connections.append(backingOffConnection) @@ -476,7 +550,11 @@ extension HTTPConnectionPool { "we should not create more than one connection per event loop" ) - let connection = HTTP2ConnectionState(connectionID: self.generator.next(), eventLoop: eventLoop) + let connection = HTTP2ConnectionState( + connectionID: self.generator.next(), + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) self.connections.append(connection) return connection.connectionID } @@ -489,11 +567,17 @@ extension HTTPConnectionPool { /// - Returns: An index and an ``EstablishedConnectionContext`` to determine the next action for the now idle connection. /// Call ``leaseStreams(at:count:)`` or ``closeConnection(at:)`` with the supplied index after /// this. - mutating func newHTTP2ConnectionEstablished(_ connection: Connection, maxConcurrentStreams: Int) -> (Int, EstablishedConnectionContext) { + mutating func newHTTP2ConnectionEstablished( + _ connection: Connection, + maxConcurrentStreams: Int + ) -> (Int, EstablishedConnectionContext) { guard let index = self.connections.firstIndex(where: { $0.connectionID == connection.id }) else { preconditionFailure("There is a new connection that we didn't request!") } - precondition(connection.eventLoop === self.connections[index].eventLoop, "Expected the new connection to be on EL") + precondition( + connection.eventLoop === self.connections[index].eventLoop, + "Expected the new connection to be on EL" + ) let availableStreams = self.connections[index].connected(connection, maxStreams: maxConcurrentStreams) let context = EstablishedConnectionContext( availableStreams: availableStreams, @@ -661,7 +745,8 @@ extension HTTPConnectionPool { precondition(self.connections[index].isClosed) let newConnection = HTTP2ConnectionState( connectionID: self.generator.next(), - eventLoop: self.connections[index].eventLoop + eventLoop: self.connections[index].eventLoop, + maximumUses: self.maximumConnectionUses ) self.connections[index] = newConnection @@ -676,10 +761,16 @@ extension HTTPConnectionPool { // must ignore the event. return nil } - self.connections[index].fail() - let eventLoop = self.connections[index].eventLoop - let context = FailedConnectionContext(eventLoop: eventLoop) - return (index, context) + + switch self.connections[index].fail() { + case .none: + return nil + + case .removeConnection: + let eventLoop = self.connections[index].eventLoop + let context = FailedConnectionContext(eventLoop: eventLoop) + return (index, context) + } } mutating func shutdown() -> CleanupContext { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift index d3e6fbdcd..67a07e6dd 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift @@ -18,6 +18,7 @@ import NIOHTTP2 extension HTTPConnectionPool { struct HTTP2StateMachine { typealias Action = HTTPConnectionPool.StateMachine.Action + typealias RequestAction = HTTPConnectionPool.StateMachine.RequestAction typealias ConnectionMigrationAction = HTTPConnectionPool.StateMachine.ConnectionMigrationAction typealias EstablishedAction = HTTPConnectionPool.StateMachine.EstablishedAction typealias EstablishedConnectionAction = HTTPConnectionPool.StateMachine.EstablishedConnectionAction @@ -33,16 +34,25 @@ extension HTTPConnectionPool { private let idGenerator: Connection.ID.Generator private(set) var lifecycleState: StateMachine.LifecycleState + /// The property was introduced to fail fast during testing. + /// Otherwise this should always be true and not turned off. + private let retryConnectionEstablishment: Bool init( idGenerator: Connection.ID.Generator, - lifecycleState: StateMachine.LifecycleState + retryConnectionEstablishment: Bool, + lifecycleState: StateMachine.LifecycleState, + maximumConnectionUses: Int? ) { self.idGenerator = idGenerator self.requests = RequestQueue() - self.connections = HTTP2Connections(generator: idGenerator) + self.connections = HTTP2Connections( + generator: idGenerator, + maximumConnectionUses: maximumConnectionUses + ) self.lifecycleState = lifecycleState + self.retryConnectionEstablishment = retryConnectionEstablishment } mutating func migrateFromHTTP1( @@ -75,7 +85,10 @@ extension HTTPConnectionPool { requests: RequestQueue ) -> ConnectionMigrationAction { precondition(self.connections.isEmpty, "expected an empty state machine but connections are not empty") - precondition(self.http1Connections == nil, "expected an empty state machine but http1Connections are not nil") + precondition( + self.http1Connections == nil, + "expected an empty state machine but http1Connections are not nil" + ) precondition(self.requests.isEmpty, "expected an empty state machine but requests are not empty") self.requests = requests @@ -85,7 +98,7 @@ extension HTTPConnectionPool { self.connections = http2Connections } - var http1Connections = http1Connections // make http1Connections mutable + var http1Connections = http1Connections // make http1Connections mutable let context = http1Connections.migrateToHTTP2() self.connections.migrateFromHTTP1( starting: context.starting, @@ -207,23 +220,24 @@ extension HTTPConnectionPool { .init(self._newHTTP2ConnectionEstablished(connection, maxConcurrentStreams: maxConcurrentStreams)) } - private mutating func _newHTTP2ConnectionEstablished(_ connection: Connection, maxConcurrentStreams: Int) -> EstablishedAction { + private mutating func _newHTTP2ConnectionEstablished( + _ connection: Connection, + maxConcurrentStreams: Int + ) -> EstablishedAction { self.failedConsecutiveConnectionAttempts = 0 self.lastConnectFailure = nil - if self.connections.hasActiveConnection(for: connection.eventLoop) { - guard let (index, _) = self.connections.failConnection(connection.id) else { - preconditionFailure("we have established a new connection that we know nothing about?") - } - self.connections.removeConnection(at: index) + let doesConnectionExistsForEL = self.connections.hasActiveConnection(for: connection.eventLoop) + let (index, context) = self.connections.newHTTP2ConnectionEstablished( + connection, + maxConcurrentStreams: maxConcurrentStreams + ) + if doesConnectionExistsForEL { + let connection = self.connections.closeConnection(at: index) return .init( request: .none, connection: .closeConnection(connection, isShutdown: .no) ) } else { - let (index, context) = self.connections.newHTTP2ConnectionEstablished( - connection, - maxConcurrentStreams: maxConcurrentStreams - ) return self.nextActionForAvailableConnection(at: index, context: context) } } @@ -288,8 +302,14 @@ extension HTTPConnectionPool { } } - mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action { - guard let (index, context) = self.connections.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) else { + mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action + { + guard + let (index, context) = self.connections.newHTTP2MaxConcurrentStreamsReceived( + connectionID, + newMaxStreams: newMaxStreams + ) + else { // When a connection close is initiated by the connection pool, the connection will // still report further events (like newMaxConcurrentStreamsReceived) to the state // machine. In those cases we must ignore the event. @@ -333,15 +353,15 @@ extension HTTPConnectionPool { // we need to start a new on connection in two cases: let needGeneralPurposeConnection = // 1. if we have general purpose requests - !self.requests.isEmpty(for: nil) && + !self.requests.isEmpty(for: nil) // and no connection starting or active - !context.hasGeneralPurposeConnection + && !context.hasGeneralPurposeConnection let needRequiredEventLoopConnection = // 2. or if we have requests for a required event loop - !self.requests.isEmpty(for: eventLoop) && + !self.requests.isEmpty(for: eventLoop) // and no connection starting or active for the given event loop - !context.hasConnectionOnSpecifiedEventLoop + && !context.hasConnectionOnSpecifiedEventLoop guard needGeneralPurposeConnection || needRequiredEventLoopConnection else { // otherwise we can remove the connection @@ -349,7 +369,8 @@ extension HTTPConnectionPool { return .none } - let (newConnectionID, previousEventLoop) = self.connections.createNewConnectionByReplacingClosedConnection(at: index) + let (newConnectionID, previousEventLoop) = self.connections + .createNewConnectionByReplacingClosedConnection(at: index) precondition(previousEventLoop === eventLoop) return .init( @@ -402,8 +423,44 @@ extension HTTPConnectionPool { self.lastConnectFailure = error let eventLoop = self.connections.backoffNextConnectionAttempt(connectionID) - let backoff = calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) - return .init(request: .none, connection: .scheduleBackoffTimer(connectionID, backoff: backoff, on: eventLoop)) + + switch self.lifecycleState { + case .running: + guard self.retryConnectionEstablishment else { + guard let (index, _) = self.connections.failConnection(connectionID) else { + preconditionFailure( + "A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost." + ) + } + self.connections.removeConnection(at: index) + + return .init( + request: self.failAllRequests(reason: error), + connection: .none + ) + } + + let backoff = calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) + return .init( + request: .none, + connection: .scheduleBackoffTimer(connectionID, backoff: backoff, on: eventLoop) + ) + case .shuttingDown: + guard let (index, context) = self.connections.failConnection(connectionID) else { + preconditionFailure( + "A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost." + ) + } + return self.nextActionForFailedConnection(at: index, on: context.eventLoop) + case .shutDown: + preconditionFailure("If the pool is already shutdown, all connections must have been torn down.") + } + } + + mutating func waitingForConnectivity(_ error: Error, connectionID: Connection.ID) -> Action { + self.lastConnectFailure = error + + return .init(request: .none, connection: .none) } mutating func connectionCreationBackoffDone(_ connectionID: Connection.ID) -> Action { @@ -416,6 +473,14 @@ extension HTTPConnectionPool { return self.nextActionForFailedConnection(at: index, on: context.eventLoop) } + private mutating func failAllRequests(reason error: Error) -> RequestAction { + let allRequests = self.requests.removeAll() + guard !allRequests.isEmpty else { + return .none + } + return .failRequestsAndCancelTimeouts(allRequests, error) + } + mutating func timeoutRequest(_ requestID: Request.ID) -> Action { // 1. check requests in queue if let request = self.requests.remove(requestID) { @@ -439,9 +504,11 @@ extension HTTPConnectionPool { mutating func cancelRequest(_ requestID: Request.ID) -> Action { // 1. check requests in queue - if self.requests.remove(requestID) != nil { + if let request = self.requests.remove(requestID) { + // Use the last connection error to let the user know why the request was never scheduled + let error = self.lastConnectFailure ?? HTTPClientError.cancelled return .init( - request: .cancelRequestTimeout(requestID), + request: .failRequest(request, error, cancelTimeout: true), connection: .none ) } @@ -459,7 +526,10 @@ extension HTTPConnectionPool { return .none } - precondition(self.lifecycleState == .running, "If we are shutting down, we must not have any idle connections") + precondition( + self.lifecycleState == .running, + "If we are shutting down, we must not have any idle connections" + ) return .init( request: .none, @@ -512,7 +582,10 @@ extension HTTPConnectionPool { case .shuttingDown(let unclean): if self.connections.isEmpty { // if the http2connections are empty as well, there are no more connections. Shutdown completed. - return .init(request: .none, connection: .closeConnection(connection, isShutdown: .yes(unclean: unclean))) + return .init( + request: .none, + connection: .closeConnection(connection, isShutdown: .yes(unclean: unclean)) + ) } else { return .init(request: .none, connection: .closeConnection(connection, isShutdown: .no)) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift index 4d912633c..b905e26bd 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift @@ -26,7 +26,9 @@ extension HTTPConnectionPool { self.connection = connection } - static let none = Action(request: .none, connection: .none) + static var none: Action { + Action(request: .none, connection: .none) + } } enum ConnectionAction { @@ -40,9 +42,24 @@ extension HTTPConnectionPool { case scheduleTimeoutTimer(Connection.ID, on: EventLoop) case cancelTimeoutTimer(Connection.ID) + case createConnectionAndCancelTimeoutTimer( + createdID: Connection.ID, + on: EventLoop, + cancelTimerID: Connection.ID + ) + case scheduleTimeoutTimerAndCreateConnection( + timeoutID: Connection.ID, + newConnectionID: Connection.ID, + on: EventLoop + ) case closeConnection(Connection, isShutdown: IsShutdown) case cleanupConnections(CleanupContext, isShutdown: IsShutdown) + case closeConnectionAndCreateConnection( + closeConnection: Connection, + newConnectionID: Connection.ID, + on: EventLoop + ) case migration( createConnections: [(Connection.ID, EventLoop)], @@ -61,7 +78,6 @@ extension HTTPConnectionPool { case failRequestsAndCancelTimeouts([Request], Error) case scheduleRequestTimeout(for: Request, on: EventLoop) - case cancelRequestTimeout(Request.ID) case none } @@ -97,24 +113,56 @@ extension HTTPConnectionPool { let idGenerator: Connection.ID.Generator let maximumConcurrentHTTP1Connections: Int - - init(idGenerator: Connection.ID.Generator, maximumConcurrentHTTP1Connections: Int) { + /// The property was introduced to fail fast during testing. + /// Otherwise this should always be true and not turned off. + private let retryConnectionEstablishment: Bool + let maximumConnectionUses: Int? + let preWarmedHTTP1ConnectionCount: Int + + init( + idGenerator: Connection.ID.Generator, + maximumConcurrentHTTP1Connections: Int, + retryConnectionEstablishment: Bool, + preferHTTP1: Bool, + maximumConnectionUses: Int?, + preWarmedHTTP1ConnectionCount: Int + ) { self.maximumConcurrentHTTP1Connections = maximumConcurrentHTTP1Connections + self.retryConnectionEstablishment = retryConnectionEstablishment self.idGenerator = idGenerator - let http1State = HTTP1StateMachine( - idGenerator: idGenerator, - maximumConcurrentConnections: maximumConcurrentHTTP1Connections, - lifecycleState: .running - ) - self.state = .http1(http1State) + self.maximumConnectionUses = maximumConnectionUses + self.preWarmedHTTP1ConnectionCount = preWarmedHTTP1ConnectionCount + + if preferHTTP1 { + let http1State = HTTP1StateMachine( + idGenerator: idGenerator, + maximumConcurrentConnections: maximumConcurrentHTTP1Connections, + retryConnectionEstablishment: retryConnectionEstablishment, + maximumConnectionUses: maximumConnectionUses, + preWarmedHTTP1ConnectionCount: preWarmedHTTP1ConnectionCount, + lifecycleState: .running + ) + self.state = .http1(http1State) + } else { + let http2State = HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: retryConnectionEstablishment, + lifecycleState: .running, + maximumConnectionUses: maximumConnectionUses + ) + self.state = .http2(http2State) + } } mutating func executeRequest(_ request: Request) -> Action { - self.state.modify(http1: { http1 in - http1.executeRequest(request) - }, http2: { http2 in - http2.executeRequest(request) - }) + self.state.modify( + http1: { http1 in + http1.executeRequest(request) + }, + http2: { http2 in + http2.executeRequest(request) + } + ) } mutating func newHTTP1ConnectionCreated(_ connection: Connection) -> Action { @@ -128,6 +176,9 @@ extension HTTPConnectionPool { var http1StateMachine = HTTP1StateMachine( idGenerator: self.idGenerator, maximumConcurrentConnections: self.maximumConcurrentHTTP1Connections, + retryConnectionEstablishment: self.retryConnectionEstablishment, + maximumConnectionUses: self.maximumConnectionUses, + preWarmedHTTP1ConnectionCount: self.preWarmedHTTP1ConnectionCount, lifecycleState: http2StateMachine.lifecycleState ) @@ -148,7 +199,9 @@ extension HTTPConnectionPool { var http2StateMachine = HTTP2StateMachine( idGenerator: self.idGenerator, - lifecycleState: http1StateMachine.lifecycleState + retryConnectionEstablishment: self.retryConnectionEstablishment, + lifecycleState: http1StateMachine.lifecycleState, + maximumConnectionUses: self.maximumConnectionUses ) let migrationAction = http2StateMachine.migrateFromHTTP1( http1Connections: http1StateMachine.connections, @@ -171,52 +224,82 @@ extension HTTPConnectionPool { } } - mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action { - self.state.modify(http1: { http1 in - http1.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) - }, http2: { http2 in - http2.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) - }) + mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action + { + self.state.modify( + http1: { http1 in + http1.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) + }, + http2: { http2 in + http2.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) + } + ) } mutating func http2ConnectionGoAwayReceived(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http2ConnectionGoAwayReceived(connectionID) - }, http2: { http2 in - http2.http2ConnectionGoAwayReceived(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http2ConnectionGoAwayReceived(connectionID) + }, + http2: { http2 in + http2.http2ConnectionGoAwayReceived(connectionID) + } + ) } mutating func http2ConnectionClosed(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http2ConnectionClosed(connectionID) - }, http2: { http2 in - http2.http2ConnectionClosed(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http2ConnectionClosed(connectionID) + }, + http2: { http2 in + http2.http2ConnectionClosed(connectionID) + } + ) } mutating func http2ConnectionStreamClosed(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http2ConnectionStreamClosed(connectionID) - }, http2: { http2 in - http2.http2ConnectionStreamClosed(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http2ConnectionStreamClosed(connectionID) + }, + http2: { http2 in + http2.http2ConnectionStreamClosed(connectionID) + } + ) } mutating func failedToCreateNewConnection(_ error: Error, connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.failedToCreateNewConnection(error, connectionID: connectionID) - }, http2: { http2 in - http2.failedToCreateNewConnection(error, connectionID: connectionID) - }) + self.state.modify( + http1: { http1 in + http1.failedToCreateNewConnection(error, connectionID: connectionID) + }, + http2: { http2 in + http2.failedToCreateNewConnection(error, connectionID: connectionID) + } + ) + } + + mutating func waitingForConnectivity(_ error: Error, connectionID: Connection.ID) -> Action { + self.state.modify( + http1: { http1 in + http1.waitingForConnectivity(error, connectionID: connectionID) + }, + http2: { http2 in + http2.waitingForConnectivity(error, connectionID: connectionID) + } + ) } mutating func connectionCreationBackoffDone(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.connectionCreationBackoffDone(connectionID) - }, http2: { http2 in - http2.connectionCreationBackoffDone(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.connectionCreationBackoffDone(connectionID) + }, + http2: { http2 in + http2.connectionCreationBackoffDone(connectionID) + } + ) } /// A request has timed out. @@ -225,11 +308,14 @@ extension HTTPConnectionPool { /// request, but don't need to cancel the timer (it already triggered). If a request is cancelled /// we don't need to fail it but we need to cancel its timeout timer. mutating func timeoutRequest(_ requestID: Request.ID) -> Action { - self.state.modify(http1: { http1 in - http1.timeoutRequest(requestID) - }, http2: { http2 in - http2.timeoutRequest(requestID) - }) + self.state.modify( + http1: { http1 in + http1.timeoutRequest(requestID) + }, + http2: { http2 in + http2.timeoutRequest(requestID) + } + ) } /// A request was cancelled. @@ -238,44 +324,59 @@ extension HTTPConnectionPool { /// need to cancel its timeout timer. If a request times out, we need to fail the request, but don't /// need to cancel the timer (it already triggered). mutating func cancelRequest(_ requestID: Request.ID) -> Action { - self.state.modify(http1: { http1 in - http1.cancelRequest(requestID) - }, http2: { http2 in - http2.cancelRequest(requestID) - }) + self.state.modify( + http1: { http1 in + http1.cancelRequest(requestID) + }, + http2: { http2 in + http2.cancelRequest(requestID) + } + ) } - mutating func connectionIdleTimeout(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.connectionIdleTimeout(connectionID) - }, http2: { http2 in - http2.connectionIdleTimeout(connectionID) - }) + mutating func connectionIdleTimeout(_ connectionID: Connection.ID, on eventLoop: any EventLoop) -> Action { + self.state.modify( + http1: { http1 in + http1.connectionIdleTimeout(connectionID, on: eventLoop) + }, + http2: { http2 in + http2.connectionIdleTimeout(connectionID) + } + ) } /// A connection has been closed mutating func http1ConnectionClosed(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http1ConnectionClosed(connectionID) - }, http2: { http2 in - http2.http1ConnectionClosed(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http1ConnectionClosed(connectionID) + }, + http2: { http2 in + http2.http1ConnectionClosed(connectionID) + } + ) } mutating func http1ConnectionReleased(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http1ConnectionReleased(connectionID) - }, http2: { http2 in - http2.http1ConnectionReleased(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http1ConnectionReleased(connectionID) + }, + http2: { http2 in + http2.http1ConnectionReleased(connectionID) + } + ) } mutating func shutdown() -> Action { - return self.state.modify(http1: { http1 in - http1.shutdown() - }, http2: { http2 in - http2.shutdown() - }) + self.state.modify( + http1: { http1 in + http1.shutdown() + }, + http2: { http2 in + http2.shutdown() + } + ) } } } @@ -318,7 +419,9 @@ extension HTTPConnectionPool.StateMachine { } struct EstablishedAction { - static let none: Self = .init(request: .none, connection: .none) + static var none: Self { + Self(request: .none, connection: .none) + } let request: HTTPConnectionPool.StateMachine.RequestAction let connection: EstablishedConnectionAction } @@ -326,7 +429,24 @@ extension HTTPConnectionPool.StateMachine { enum EstablishedConnectionAction { case none case scheduleTimeoutTimer(HTTPConnectionPool.Connection.ID, on: EventLoop) - case closeConnection(HTTPConnectionPool.Connection, isShutdown: HTTPConnectionPool.StateMachine.ConnectionAction.IsShutdown) + case closeConnection( + HTTPConnectionPool.Connection, + isShutdown: HTTPConnectionPool.StateMachine.ConnectionAction.IsShutdown + ) + case scheduleTimeoutTimerAndCreateConnection( + timeoutID: HTTPConnectionPool.Connection.ID, + newConnectionID: HTTPConnectionPool.Connection.ID, + on: EventLoop + ) + case closeConnectionAndCreateConnection( + closeConnection: HTTPConnectionPool.Connection, + newConnectionID: HTTPConnectionPool.Connection.ID, + on: EventLoop + ) + case createConnection( + connectionID: HTTPConnectionPool.Connection.ID, + on: EventLoop + ) } } @@ -348,6 +468,24 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction { self = .scheduleTimeoutTimer(connectionID, on: eventLoop) case .closeConnection(let connection, let isShutdown): self = .closeConnection(connection, isShutdown: isShutdown) + case .closeConnectionAndCreateConnection( + let closeConnection, + let newConnectionID, + let eventLoop + ): + self = .closeConnectionAndCreateConnection( + closeConnection: closeConnection, + newConnectionID: newConnectionID, + on: eventLoop + ) + case .scheduleTimeoutTimerAndCreateConnection(let timeoutID, let newConnectionID, let eventLoop): + self = .scheduleTimeoutTimerAndCreateConnection( + timeoutID: timeoutID, + newConnectionID: newConnectionID, + on: eventLoop + ) + case .createConnection(let connectionID, on: let eventLoop): + self = .createConnection(connectionID, on: eventLoop) } } } @@ -358,28 +496,32 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction { _ establishedAction: HTTPConnectionPool.StateMachine.EstablishedConnectionAction ) -> Self { switch establishedAction { - case .none: + case .none, .createConnection: + // createConnection can only come from the HTTP/1 pool, so we only see this when + // migrating to HTTP/2. We can ignore it there: we already have a connection to use. return .migration( createConnections: migrationAction.createConnections, closeConnections: migrationAction.closeConnections, scheduleTimeout: nil ) + case .closeConnectionAndCreateConnection( + closeConnection: let connection, + newConnectionID: _, + on: _ + ): + // This event can only come _from_ the HTTP/1 pool, migrating to HTTP/2. We do not do prewarmed HTTP/2 connections, + // so we can ignore the request for a new connection. This is thus the same as the case below. + return Self.closeConnection(connection, isShutdown: .no, migrationAction: migrationAction) case .closeConnection(let connection, let isShutdown): - guard isShutdown == .no else { - precondition( - migrationAction.closeConnections.isEmpty && - migrationAction.createConnections.isEmpty, - "migration actions are not supported during shutdown" - ) - return .closeConnection(connection, isShutdown: isShutdown) - } - var closeConnections = migrationAction.closeConnections - closeConnections.append(connection) - return .migration( - createConnections: migrationAction.createConnections, - closeConnections: closeConnections, - scheduleTimeout: nil - ) + return Self.closeConnection(connection, isShutdown: isShutdown, migrationAction: migrationAction) + case .scheduleTimeoutTimerAndCreateConnection( + timeoutID: let connectionID, + newConnectionID: _, + on: let eventLoop + ): + // This event can only come _from_ the HTTP/1 pool, migrating to HTTP/2. We do not do prewarmed HTTP/2 connections, + // so we can ignore the request for a new connection. This is thus the same as the case below. + fallthrough case .scheduleTimeoutTimer(let connectionID, let eventLoop): return .migration( createConnections: migrationAction.createConnections, @@ -388,4 +530,25 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction { ) } } + + private static func closeConnection( + _ connection: HTTPConnectionPool.Connection, + isShutdown: HTTPConnectionPool.StateMachine.ConnectionAction.IsShutdown, + migrationAction: HTTPConnectionPool.StateMachine.ConnectionMigrationAction + ) -> Self { + guard isShutdown == .no else { + precondition( + migrationAction.closeConnections.isEmpty && migrationAction.createConnections.isEmpty, + "migration actions are not supported during shutdown" + ) + return .closeConnection(connection, isShutdown: isShutdown) + } + var closeConnections = migrationAction.closeConnections + closeConnections.append(connection) + return .migration( + createConnections: migrationAction.createConnections, + closeConnections: closeConnections, + scheduleTimeout: nil + ) + } } diff --git a/Sources/AsyncHTTPClient/DeconstructedURL.swift b/Sources/AsyncHTTPClient/DeconstructedURL.swift index 020c17455..f7d0b1977 100644 --- a/Sources/AsyncHTTPClient/DeconstructedURL.swift +++ b/Sources/AsyncHTTPClient/DeconstructedURL.swift @@ -48,9 +48,16 @@ extension DeconstructedURL { switch scheme { case .http, .https: + #if !canImport(Darwin) + guard let urlHost = url.host, !urlHost.isEmpty else { + throw HTTPClientError.emptyHost + } + let host = urlHost.trimIPv6Brackets() + #else guard let host = url.host, !host.isEmpty else { throw HTTPClientError.emptyHost } + #endif self.init( scheme: scheme, connectionTarget: .init(remoteHost: host, port: url.port ?? scheme.defaultPort), @@ -81,3 +88,26 @@ extension DeconstructedURL { } } } + +#if !canImport(Darwin) +extension String { + @inlinable internal func trimIPv6Brackets() -> String { + var utf8View = self.utf8[...] + + var modified = false + if utf8View.first == UInt8(ascii: "[") { + utf8View = utf8View.dropFirst() + modified = true + } + if utf8View.last == UInt8(ascii: "]") { + utf8View = utf8View.dropLast() + modified = true + } + + if modified { + return String(Substring(utf8View)) + } + return self + } +} +#endif diff --git a/Sources/AsyncHTTPClient/Docs.docc/index.md b/Sources/AsyncHTTPClient/Docs.docc/index.md new file mode 100644 index 000000000..37033e043 --- /dev/null +++ b/Sources/AsyncHTTPClient/Docs.docc/index.md @@ -0,0 +1,325 @@ +# ``AsyncHTTPClient`` + +This package provides simple HTTP Client library built on top of SwiftNIO. + +## Overview + +This library provides the following: +- First class support for Swift Concurrency (since version 1.9.0) +- Asynchronous and non-blocking request methods +- Simple follow-redirects (cookie headers are dropped) +- Streaming body download +- TLS support +- Automatic HTTP/2 over HTTPS (since version 1.7.0) +- Cookie parsing (but not storage) + +### Getting Started + +#### Adding the dependency + +Add the following entry in your Package.swift to start using HTTPClient: + +```swift +.package(url: "https://github.com/swift-server/async-http-client.git", from: "1.9.0") +``` +and `AsyncHTTPClient` dependency to your target: +```swift +.target(name: "MyApp", dependencies: [.product(name: "AsyncHTTPClient", package: "async-http-client")]), +``` + +#### Request-Response API + +The code snippet below illustrates how to make a simple GET request to a remote server. + +```swift +import AsyncHTTPClient + +/// MARK: - Using Swift Concurrency +let request = HTTPClientRequest(url: "https://apple.com/") +let response = try await httpClient.execute(request, timeout: .seconds(30)) +print("HTTP head", response) +if response.status == .ok { + let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB + // handle body +} else { + // handle remote error +} + + +/// MARK: - Using SwiftNIO EventLoopFuture +HTTPClient.shared.get(url: "https://apple.com/").whenComplete { result in + switch result { + case .failure(let error): + // process error + case .success(let response): + if response.status == .ok { + // handle response + } else { + // handle remote error + } + } +} +``` + +You should always shut down ``HTTPClient`` instances you created using ``HTTPClient/shutdown()-9gcpw``. Please note that you must not call ``HTTPClient/shutdown()-9gcpw`` before all requests of the HTTP client have finished, or else the in-flight requests will likely fail because their network connections are interrupted. + +#### async/await examples + +Examples for the async/await API can be found in the [`Examples` folder](https://github.com/swift-server/async-http-client/tree/main/Examples) in the repository. + +### Usage guide + +The default HTTP Method is `GET`. In case you need to have more control over the method, or you want to add headers or body, use the ``HTTPClientRequest`` struct: + +#### Using Swift Concurrency + +```swift +import AsyncHTTPClient + +do { + var request = HTTPClientRequest(url: "https://apple.com/") + request.method = .POST + request.headers.add(name: "User-Agent", value: "Swift HTTPClient") + request.body = .bytes(ByteBuffer(string: "some data")) + + let response = try await HTTPClient.shared.execute(request, timeout: .seconds(30)) + if response.status == .ok { + // handle response + } else { + // handle remote error + } +} catch { + // handle error +} +``` + +#### Using SwiftNIO EventLoopFuture + +```swift +import AsyncHTTPClient + +var request = try HTTPClient.Request(url: "https://apple.com/", method: .POST) +request.headers.add(name: "User-Agent", value: "Swift HTTPClient") +request.body = .string("some-body") + +HTTPClient.shared.execute(request: request).whenComplete { result in + switch result { + case .failure(let error): + // process error + case .success(let response): + if response.status == .ok { + // handle response + } else { + // handle remote error + } + } +} +``` + +#### Redirects following +Enable follow-redirects behavior using the client configuration: +```swift +let httpClient = HTTPClient(eventLoopGroupProvider: .singleton, + configuration: HTTPClient.Configuration(followRedirects: true)) +``` + +#### Timeouts +Timeouts (connect and read) can also be set using the client configuration: +```swift +let timeout = HTTPClient.Configuration.Timeout(connect: .seconds(1), read: .seconds(1)) +let httpClient = HTTPClient(eventLoopGroupProvider: .singleton, + configuration: HTTPClient.Configuration(timeout: timeout)) +``` +or on a per-request basis: +```swift +httpClient.execute(request: request, deadline: .now() + .milliseconds(1)) +``` + +#### Streaming +When dealing with larger amount of data, it's critical to stream the response body instead of aggregating in-memory. +The following example demonstrates how to count the number of bytes in a streaming response body: + +##### Using Swift Concurrency +```swift +do { + let request = HTTPClientRequest(url: "https://apple.com/") + let response = try await HTTPClient.shared.execute(request, timeout: .seconds(30)) + print("HTTP head", response) + + // if defined, the content-length headers announces the size of the body + let expectedBytes = response.headers.first(name: "content-length").flatMap(Int.init) + + var receivedBytes = 0 + // asynchronously iterates over all body fragments + // this loop will automatically propagate backpressure correctly + for try await buffer in response.body { + // for this example, we are just interested in the size of the fragment + receivedBytes += buffer.readableBytes + + if let expectedBytes = expectedBytes { + // if the body size is known, we calculate a progress indicator + let progress = Double(receivedBytes) / Double(expectedBytes) + print("progress: \(Int(progress * 100))%") + } + } + print("did receive \(receivedBytes) bytes") +} catch { + print("request failed:", error) +} +``` + +##### Using HTTPClientResponseDelegate and SwiftNIO EventLoopFuture + +```swift +import NIOCore +import NIOHTTP1 + +class CountingDelegate: HTTPClientResponseDelegate { + typealias Response = Int + + var count = 0 + + func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { + // this is executed right after request head was sent, called once + } + + func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { + // this is executed when request body part is sent, could be called zero or more times + } + + func didSendRequest(task: HTTPClient.Task) { + // this is executed when request is fully sent, called once + } + + func didReceiveHead( + task: HTTPClient.Task, + _ head: HTTPResponseHead + ) -> EventLoopFuture { + // this is executed when we receive HTTP response head part of the request + // (it contains response code and headers), called once in case backpressure + // is needed, all reads will be paused until returned future is resolved + return task.eventLoop.makeSucceededFuture(()) + } + + func didReceiveBodyPart( + task: HTTPClient.Task, + _ buffer: ByteBuffer + ) -> EventLoopFuture { + // this is executed when we receive parts of the response body, could be called zero or more times + count += buffer.readableBytes + // in case backpressure is needed, all reads will be paused until returned future is resolved + return task.eventLoop.makeSucceededFuture(()) + } + + func didFinishRequest(task: HTTPClient.Task) throws -> Int { + // this is called when the request is fully read, called once + // this is where you return a result or throw any errors you require to propagate to the client + return count + } + + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + // this is called when we receive any network-related error, called once + } +} + +let request = try HTTPClient.Request(url: "https://apple.com/") +let delegate = CountingDelegate() + +httpClient.execute(request: request, delegate: delegate).futureResult.whenSuccess { count in + print(count) +} +``` + +#### File downloads + +Based on the `HTTPClientResponseDelegate` example above you can build more complex delegates, +the built-in `FileDownloadDelegate` is one of them. It allows streaming the downloaded data +asynchronously, while reporting the download progress at the same time, like in the following +example: + +```swift +let request = try HTTPClient.Request( + url: "https://swift.org/builds/development/ubuntu1804/latest-build.yml" +) + +let delegate = try FileDownloadDelegate(path: "/tmp/latest-build.yml", reportProgress: { + if let totalBytes = $0.totalBytes { + print("Total bytes count: \(totalBytes)") + } + print("Downloaded \($0.receivedBytes) bytes so far") +}) + +HTTPClient.shared.execute(request: request, delegate: delegate).futureResult + .whenSuccess { progress in + if let totalBytes = progress.totalBytes { + print("Final total bytes count: \(totalBytes)") + } + print("Downloaded finished with \(progress.receivedBytes) bytes downloaded") + } +``` + +#### Unix Domain Socket Paths +Connecting to servers bound to socket paths is easy: +```swift +HTTPClient.shared.execute( + .GET, + socketPath: "/tmp/myServer.socket", + urlPath: "/path/to/resource" +).whenComplete (...) +``` + +Connecting over TLS to a unix domain socket path is possible as well: +```swift +HTTPClient.shared.execute( + .POST, + secureSocketPath: "/tmp/myServer.socket", + urlPath: "/path/to/resource", + body: .string("hello") +).whenComplete (...) +``` + +Direct URLs can easily be constructed to be executed in other scenarios: +```swift +let socketPathBasedURL = URL( + httpURLWithSocketPath: "/tmp/myServer.socket", + uri: "/path/to/resource" +) +let secureSocketPathBasedURL = URL( + httpsURLWithSocketPath: "/tmp/myServer.socket", + uri: "/path/to/resource" +) +``` + +#### Disabling HTTP/2 +The exclusive use of HTTP/1 is possible by setting ``HTTPClient/Configuration/httpVersion-swift.property`` to ``HTTPClient/Configuration/HTTPVersion-swift.struct/http1Only`` on the ``HTTPClient/Configuration``: +```swift +var configuration = HTTPClient.Configuration() +configuration.httpVersion = .http1Only +let client = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configuration +) +``` + +### Security + +AsyncHTTPClient's security process is documented on [GitHub](https://github.com/swift-server/async-http-client/blob/main/SECURITY.md). + +## Topics + +### HTTPClient + +- ``HTTPClient`` +- ``HTTPClientRequest`` +- ``HTTPClientResponse`` + +### HTTP Client Delegates + +- ``HTTPClientResponseDelegate`` +- ``ResponseAccumulator`` +- ``FileDownloadDelegate`` +- ``HTTPClientCopyingDelegate`` + +### Errors + +- ``HTTPClientError`` diff --git a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift index 6f046dce9..33a4d3cb2 100644 --- a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift +++ b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift @@ -12,65 +12,179 @@ // //===----------------------------------------------------------------------===// +import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import NIOPosix +import struct Foundation.URL + /// Handles a streaming download to a given file path, allowing headers and progress to be reported. public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// The response type for this delegate: the total count of bytes as reported by the response - /// "Content-Length" header (if available) and the count of bytes downloaded. - public struct Progress { + /// "Content-Length" header (if available), the count of bytes downloaded, the + /// response head, and a history of requests and responses. + public struct Progress: Sendable { public var totalBytes: Int? public var receivedBytes: Int + + /// The history of all requests and responses in redirect order. + public var history: [HTTPClient.RequestResponse] = [] + + /// The target URL (after redirects) of the response. + public var url: URL? { + self.history.last?.request.url + } + + public var head: HTTPResponseHead { + get { + assert(self._head != nil) + return self._head! + } + set { + self._head = newValue + } + } + + fileprivate var _head: HTTPResponseHead? = nil + + internal init(totalBytes: Int? = nil, receivedBytes: Int) { + self.totalBytes = totalBytes + self.receivedBytes = receivedBytes + } } - private var progress = Progress(totalBytes: nil, receivedBytes: 0) + private struct State { + var progress = Progress( + totalBytes: nil, + receivedBytes: 0 + ) + var fileIOThreadPool: NIOThreadPool? + var fileHandleFuture: EventLoopFuture? + var writeFuture: EventLoopFuture? + } + private let state: NIOLockedValueBox + + var _fileIOThreadPool: NIOThreadPool? { + self.state.withLockedValue { $0.fileIOThreadPool } + } public typealias Response = Progress private let filePath: String - private let io: NonBlockingFileIO - private let reportHead: ((HTTPResponseHead) -> Void)? - private let reportProgress: ((Progress) -> Void)? - - private var fileHandleFuture: EventLoopFuture? - private var writeFuture: EventLoopFuture? + private let reportHead: (@Sendable (HTTPClient.Task, HTTPResponseHead) -> Void)? + private let reportProgress: (@Sendable (HTTPClient.Task, Progress) -> Void)? /// Initializes a new file download delegate. + /// /// - parameters: /// - path: Path to a file you'd like to write the download to. - /// - pool: A thread pool to use for asynchronous file I/O. + /// - pool: A thread pool to use for asynchronous file I/O. If nil, a shared thread pool will be used. Defaults to nil. /// - reportHead: A closure called when the response head is available. /// - reportProgress: A closure called when a body chunk has been downloaded, with /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. + @preconcurrency public init( path: String, - pool: NIOThreadPool = NIOThreadPool(numberOfThreads: 1), - reportHead: ((HTTPResponseHead) -> Void)? = nil, - reportProgress: ((Progress) -> Void)? = nil + pool: NIOThreadPool? = nil, + reportHead: (@Sendable (HTTPClient.Task, HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (HTTPClient.Task, Progress) -> Void)? = nil ) throws { - pool.start() - self.io = NonBlockingFileIO(threadPool: pool) + self.state = NIOLockedValueBox(State(fileIOThreadPool: pool)) self.filePath = path self.reportHead = reportHead self.reportProgress = reportProgress } + /// Initializes a new file download delegate. + /// + /// - parameters: + /// - path: Path to a file you'd like to write the download to. + /// - pool: A thread pool to use for asynchronous file I/O. + /// - reportHead: A closure called when the response head is available. + /// - reportProgress: A closure called when a body chunk has been downloaded, with + /// the total byte count and download byte count passed to it as arguments. The callbacks + /// will be invoked in the same threading context that the delegate itself is invoked, + /// as controlled by `EventLoopPreference`. + @preconcurrency + public convenience init( + path: String, + pool: NIOThreadPool, + reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (Progress) -> Void)? = nil + ) throws { + try self.init( + path: path, + pool: .some(pool), + reportHead: reportHead.map { reportHead in + { @Sendable _, head in + reportHead(head) + } + }, + reportProgress: reportProgress.map { reportProgress in + { @Sendable _, head in + reportProgress(head) + } + } + ) + } + + /// Initializes a new file download delegate and uses the shared thread pool of the ``HTTPClient`` for file I/O. + /// + /// - parameters: + /// - path: Path to a file you'd like to write the download to. + /// - reportHead: A closure called when the response head is available. + /// - reportProgress: A closure called when a body chunk has been downloaded, with + /// the total byte count and download byte count passed to it as arguments. The callbacks + /// will be invoked in the same threading context that the delegate itself is invoked, + /// as controlled by `EventLoopPreference`. + @preconcurrency + public convenience init( + path: String, + reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (Progress) -> Void)? = nil + ) throws { + try self.init( + path: path, + pool: nil, + reportHead: reportHead.map { reportHead in + { @Sendable _, head in + reportHead(head) + } + }, + reportProgress: reportProgress.map { reportProgress in + { @Sendable _, head in + reportProgress(head) + } + } + ) + } + + public func didVisitURL(task: HTTPClient.Task, _ request: HTTPClient.Request, _ head: HTTPResponseHead) { + self.state.withLockedValue { + $0.progress.history.append(.init(request: request, responseHead: head)) + } + } + public func didReceiveHead( task: HTTPClient.Task, _ head: HTTPResponseHead ) -> EventLoopFuture { - self.reportHead?(head) + self.state.withLockedValue { + $0.progress._head = head - if let totalBytesString = head.headers.first(name: "Content-Length"), - let totalBytes = Int(totalBytesString) { - self.progress.totalBytes = totalBytes + if let totalBytesString = head.headers.first(name: "Content-Length"), + let totalBytes = Int(totalBytesString) + { + $0.progress.totalBytes = totalBytes + } } + self.reportHead?(task, head) + return task.eventLoop.makeSucceededFuture(()) } @@ -78,44 +192,90 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { task: HTTPClient.Task, _ buffer: ByteBuffer ) -> EventLoopFuture { - self.progress.receivedBytes += buffer.readableBytes - self.reportProgress?(self.progress) + let (progress, io) = self.state.withLockedValue { state in + let threadPool: NIOThreadPool = { + guard let pool = state.fileIOThreadPool else { + let pool = task.fileIOThreadPool + state.fileIOThreadPool = pool + return pool + } + return pool + }() - let writeFuture: EventLoopFuture - if let fileHandleFuture = self.fileHandleFuture { - writeFuture = fileHandleFuture.flatMap { - self.io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) - } - } else { - let fileHandleFuture = self.io.openFile( - path: self.filePath, - mode: .write, - flags: .allowFileCreation(), - eventLoop: task.eventLoop - ) - self.fileHandleFuture = fileHandleFuture - writeFuture = fileHandleFuture.flatMap { - self.io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + let io = NonBlockingFileIO(threadPool: threadPool) + state.progress.receivedBytes += buffer.readableBytes + return (state.progress, io) + } + self.reportProgress?(task, progress) + + let writeFuture = self.state.withLockedValue { state in + let writeFuture: EventLoopFuture + if let fileHandleFuture = state.fileHandleFuture { + writeFuture = fileHandleFuture.flatMap { + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + } + } else { + let fileHandleFuture = io.openFile( + _deprecatedPath: self.filePath, + mode: .write, + flags: .allowFileCreation(), + eventLoop: task.eventLoop + ) + state.fileHandleFuture = fileHandleFuture + writeFuture = fileHandleFuture.flatMap { + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + } } + + state.writeFuture = writeFuture + return writeFuture } - self.writeFuture = writeFuture return writeFuture } private func close(fileHandle: NIOFileHandle) { try! fileHandle.close() - self.fileHandleFuture = nil + self.state.withLockedValue { + $0.fileHandleFuture = nil + } } private func finalize() { - if let writeFuture = self.writeFuture { - writeFuture.whenComplete { _ in - self.fileHandleFuture?.whenSuccess(self.close(fileHandle:)) - self.writeFuture = nil + enum Finalize { + case writeFuture(EventLoopFuture) + case fileHandleFuture(EventLoopFuture) + case none + } + + let finalize: Finalize = self.state.withLockedValue { state in + if let writeFuture = state.writeFuture { + return .writeFuture(writeFuture) + } else if let fileHandleFuture = state.fileHandleFuture { + return .fileHandleFuture(fileHandleFuture) + } else { + return .none + } + } + + switch finalize { + case .writeFuture(let future): + future.whenComplete { _ in + let fileHandleFuture = self.state.withLockedValue { state in + let future = state.fileHandleFuture + state.fileHandleFuture = nil + state.writeFuture = nil + return future + } + + fileHandleFuture?.whenSuccess { + self.close(fileHandle: $0) + } } - } else { - self.fileHandleFuture?.whenSuccess(self.close(fileHandle:)) + case .fileHandleFuture(let future): + future.whenSuccess { self.close(fileHandle: $0) } + case .none: + () } } @@ -125,6 +285,6 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { public func didFinishRequest(task: HTTPClient.Task) throws -> Response { self.finalize() - return self.progress + return self.state.withLockedValue { $0.progress } } } diff --git a/Sources/AsyncHTTPClient/FoundationExtensions.swift b/Sources/AsyncHTTPClient/FoundationExtensions.swift index 545da756b..452cb7b13 100644 --- a/Sources/AsyncHTTPClient/FoundationExtensions.swift +++ b/Sources/AsyncHTTPClient/FoundationExtensions.swift @@ -39,7 +39,16 @@ extension HTTPClient.Cookie { /// - maxAge: The cookie's age in seconds, defaults to nil. /// - httpOnly: Whether this cookie should be used by HTTP servers only, defaults to false. /// - secure: Whether this cookie should only be sent using secure channels, defaults to false. - public init(name: String, value: String, path: String = "/", domain: String? = nil, expires: Date? = nil, maxAge: Int? = nil, httpOnly: Bool = false, secure: Bool = false) { + public init( + name: String, + value: String, + path: String = "/", + domain: String? = nil, + expires: Date? = nil, + maxAge: Int? = nil, + httpOnly: Bool = false, + secure: Bool = false + ) { // FIXME: This should be failable and validate the inputs // (for example, checking that the strings are ASCII, path begins with "/", domain is not empty, etc). self.init( @@ -59,8 +68,8 @@ extension HTTPClient.Body { /// Create and stream body using `Data`. /// /// - parameters: - /// - bytes: Body `Data` representation. + /// - data: Body `Data` representation. public static func data(_ data: Data) -> HTTPClient.Body { - return self.bytes(data) + self.bytes(data) } } diff --git a/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift b/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift index 75fc28de4..759f6728a 100644 --- a/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift +++ b/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift @@ -12,17 +12,29 @@ // //===----------------------------------------------------------------------===// +import CAsyncHTTPClient +import NIOCore import NIOHTTP1 + +#if canImport(xlocale) +import xlocale +#elseif canImport(locale_h) +import locale_h +#endif + #if canImport(Darwin) import Darwin +#elseif canImport(Musl) +import Musl +#elseif canImport(Android) +import Android #elseif canImport(Glibc) import Glibc #endif -import CAsyncHTTPClient extension HTTPClient { /// A representation of an HTTP cookie. - public struct Cookie { + public struct Cookie: Sendable { /// The name of the cookie. public var name: String /// The cookie's string value. @@ -45,7 +57,6 @@ extension HTTPClient { /// - parameters: /// - header: String representation of the `Set-Cookie` response header. /// - defaultDomain: Default domain to use if cookie was sent without one. - /// - returns: nil if the header is invalid. public init?(header: String, defaultDomain: String) { // The parsing of "Set-Cookie" headers is defined by Section 5.2, RFC-6265: // https://datatracker.ietf.org/doc/html/rfc6265#section-5.2 @@ -126,7 +137,16 @@ extension HTTPClient { /// - maxAge: The cookie's age in seconds, defaults to nil. /// - httpOnly: Whether this cookie should be used by HTTP servers only, defaults to false. /// - secure: Whether this cookie should only be sent using secure channels, defaults to false. - internal init(name: String, value: String, path: String = "/", domain: String? = nil, expires_timestamp: Int64? = nil, maxAge: Int? = nil, httpOnly: Bool = false, secure: Bool = false) { + internal init( + name: String, + value: String, + path: String = "/", + domain: String? = nil, + expires_timestamp: Int64? = nil, + maxAge: Int? = nil, + httpOnly: Bool = false, + secure: Bool = false + ) { self.name = name self.value = value self.path = path @@ -142,7 +162,7 @@ extension HTTPClient { extension HTTPClient.Response { /// List of HTTP cookies returned by the server. public var cookies: [HTTPClient.Cookie] { - return self.headers["set-cookie"].compactMap { HTTPClient.Cookie(header: $0, defaultDomain: self.host) } + self.headers["set-cookie"].compactMap { HTTPClient.Cookie(header: $0, defaultDomain: self.host) } } } @@ -196,7 +216,7 @@ extension String.UTF8View.SubSequence { } } -private let posixLocale: UnsafeMutableRawPointer = { +nonisolated(unsafe) private let posixLocale: UnsafeMutableRawPointer = { // All POSIX systems must provide a "POSIX" locale, and its date/time formats are US English. // https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/V1_chap07.html#tag_07_03_05 let _posixLocale = newlocale(LC_TIME_MASK | LC_NUMERIC_MASK, "POSIX", nil)! @@ -212,7 +232,8 @@ private func parseTimestamp(_ utf8: String.UTF8View.SubSequence, format: String) } private func parseCookieTime(_ timestampUTF8: String.UTF8View.SubSequence) -> Int64? { - if timestampUTF8.contains(where: { $0 < 0x20 /* Control characters */ || $0 == 0x7F /* DEL */ }) { + // 0x20: Control characters or 0x7F: DEL + if timestampUTF8.contains(where: { $0 < 0x20 || $0 == 0x7F }) { return nil } var timestampUTF8 = timestampUTF8 @@ -225,8 +246,8 @@ private func parseCookieTime(_ timestampUTF8: String.UTF8View.SubSequence) -> In } guard var timeComponents = parseTimestamp(timestampUTF8, format: "%a, %d %b %Y %H:%M:%S") - ?? parseTimestamp(timestampUTF8, format: "%a, %d-%b-%y %H:%M:%S") - ?? parseTimestamp(timestampUTF8, format: "%a %b %d %H:%M:%S %Y") + ?? parseTimestamp(timestampUTF8, format: "%a, %d-%b-%y %H:%M:%S") + ?? parseTimestamp(timestampUTF8, format: "%a %b %d %H:%M:%S %Y") else { return nil } diff --git a/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift b/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift index 4d2b9388f..e95c828ce 100644 --- a/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift +++ b/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift @@ -12,6 +12,8 @@ // //===----------------------------------------------------------------------===// +import NIOCore + extension HTTPClient.Configuration { /// Proxy server configuration /// Specifies the remote address of an HTTP proxy. @@ -23,7 +25,7 @@ extension HTTPClient.Configuration { /// If a `TLSConfiguration` is used in conjunction with `HTTPClient.Configuration.Proxy`, /// TLS will be established _after_ successful proxy, between your client /// and the destination server. - public struct Proxy { + public struct Proxy: Sendable, Hashable { enum ProxyType: Hashable { case http(HTTPClient.Authorization?) case socks @@ -36,7 +38,10 @@ extension HTTPClient.Configuration { /// Specifies Proxy server authorization. public var authorization: HTTPClient.Authorization? { set { - precondition(self.type == .http(self.authorization), "SOCKS authorization support is not yet implemented.") + precondition( + self.type == .http(self.authorization), + "SOCKS authorization support is not yet implemented." + ) self.type = .http(newValue) } @@ -58,7 +63,7 @@ extension HTTPClient.Configuration { /// - host: proxy server host. /// - port: proxy server port. public static func server(host: String, port: Int) -> Proxy { - return .init(host: host, port: port, type: .http(nil)) + .init(host: host, port: port, type: .http(nil)) } /// Create a HTTP proxy. @@ -68,7 +73,7 @@ extension HTTPClient.Configuration { /// - port: proxy server port. /// - authorization: proxy server authorization. public static func server(host: String, port: Int, authorization: HTTPClient.Authorization? = nil) -> Proxy { - return .init(host: host, port: port, type: .http(authorization)) + .init(host: host, port: port, type: .http(authorization)) } /// Create a SOCKSv5 proxy. @@ -76,7 +81,7 @@ extension HTTPClient.Configuration { /// - parameter port: The SOCKSv5 proxy port, defaults to 1080. /// - returns: A new instance of `Proxy` configured to connect to a `SOCKSv5` server. public static func socksServer(host: String, port: Int = 1080) -> Proxy { - return .init(host: host, port: port, type: .socks) + .init(host: host, port: port, type: .socks) } } } diff --git a/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift b/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift new file mode 100644 index 000000000..98c6555da --- /dev/null +++ b/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClient { + /// Start & automatically shut down a new ``HTTPClient``. + /// + /// This method allows to start & automatically dispose of a ``HTTPClient`` following the principle of Structured Concurrency. + /// The ``HTTPClient`` is guaranteed to be shut down upon return, whether `body` throws or not. + /// + /// This may be particularly useful if you cannot use the shared singleton (``HTTPClient/shared``). + public static func withHTTPClient( + eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger? = nil, + isolation: isolated (any Actor)? = #isolation, + _ body: (HTTPClient) async throws -> Return + ) async throws -> Return { + let logger = (backgroundActivityLogger ?? HTTPClient.loggingDisabled) + let httpClient = HTTPClient( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: logger + ) + return try await asyncDo { + try await body(httpClient) + } finally: { _ in + try await httpClient.shutdown() + } + } +} diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 9301094ef..80df3b946 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +import Atomics import Foundation import Logging import NIOConcurrencyHelpers @@ -22,10 +23,11 @@ import NIOPosix import NIOSSL import NIOTLS import NIOTransportServices +import Tracing extension Logger { private func requestInfo(_ request: HTTPClient.Request) -> Logger.Metadata.Value { - return "\(request.method) \(request.url)" + "\(request.method) \(request.url)" } func attachingRequestInformation(_ request: HTTPClient.Request, requestID: Int) -> Logger { @@ -36,15 +38,14 @@ extension Logger { } } -let globalRequestID = NIOAtomic.makeAtomic(value: 0) +let globalRequestID = ManagedAtomic(0) /// HTTPClient class provides API for request execution. /// /// Example: /// /// ```swift -/// let client = HTTPClient(eventLoopGroupProvider: .createNew) -/// client.get(url: "https://swift.org", deadline: .now() + .seconds(1)).whenComplete { result in +/// HTTPClient.shared.get(url: "https://swift.org", deadline: .now() + .seconds(1)).whenComplete { result in /// switch result { /// case .failure(let error): /// // process error @@ -57,87 +58,162 @@ let globalRequestID = NIOAtomic.makeAtomic(value: 0) /// } /// } /// ``` -/// -/// It is important to close the client instance, for example in a defer statement, after use to cleanly shutdown the underlying NIO `EventLoopGroup`: -/// -/// ```swift -/// try client.syncShutdown() -/// ``` -public class HTTPClient { +public final class HTTPClient: Sendable { + /// The `EventLoopGroup` in use by this ``HTTPClient``. + /// + /// All HTTP transactions will occur on loops owned by this group. public let eventLoopGroup: EventLoopGroup - let eventLoopGroupProvider: EventLoopGroupProvider - let configuration: Configuration let poolManager: HTTPConnectionPool.Manager - private var state: State - private let stateLock = Lock() - internal static let loggingDisabled = Logger(label: "AHC-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() }) + @usableFromInline + let configuration: Configuration + + /// Shared thread pool used for file IO. It is lazily created on first access of ``Task/fileIOThreadPool``. + private let fileIOThreadPool: NIOLockedValueBox + + private let state: NIOLockedValueBox + private let canBeShutDown: Bool - /// Create an `HTTPClient` with specified `EventLoopGroup` provider and configuration. + /// Tracer configured for this HTTPClient at configuration time. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public var tracer: (any Tracer)? { + configuration.tracing.tracer + } + + /// Access to tracing configuration in order to get configured attribute keys etc. + @usableFromInline + package var tracing: TracingConfiguration { + self.configuration.tracing + } + + static let loggingDisabled = Logger(label: "AHC-do-not-log", factory: { _ in SwiftLogNoOpLogHandler() }) + + /// Create an ``HTTPClient`` with specified `EventLoopGroup` provider and configuration. /// /// - parameters: /// - eventLoopGroupProvider: Specify how `EventLoopGroup` will be created. /// - configuration: Client configuration. - public convenience init(eventLoopGroupProvider: EventLoopGroupProvider, - configuration: Configuration = Configuration()) { - self.init(eventLoopGroupProvider: eventLoopGroupProvider, - configuration: configuration, - backgroundActivityLogger: HTTPClient.loggingDisabled) + public convenience init( + eventLoopGroupProvider: EventLoopGroupProvider, + configuration: Configuration = Configuration() + ) { + self.init( + eventLoopGroupProvider: eventLoopGroupProvider, + configuration: configuration, + backgroundActivityLogger: HTTPClient.loggingDisabled + ) + } + + /// Create an ``HTTPClient`` with specified `EventLoopGroup` and configuration. + /// + /// - parameters: + /// - eventLoopGroup: Specify how `EventLoopGroup` will be created. + /// - configuration: Client configuration. + public convenience init( + eventLoopGroup: EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration() + ) { + self.init( + eventLoopGroupProvider: .shared(eventLoopGroup), + configuration: configuration, + backgroundActivityLogger: HTTPClient.loggingDisabled + ) } - /// Create an `HTTPClient` with specified `EventLoopGroup` provider and configuration. + /// Create an ``HTTPClient`` with specified `EventLoopGroup` provider and configuration. /// /// - parameters: /// - eventLoopGroupProvider: Specify how `EventLoopGroup` will be created. /// - configuration: Client configuration. - public required init(eventLoopGroupProvider: EventLoopGroupProvider, - configuration: Configuration = Configuration(), - backgroundActivityLogger: Logger) { - self.eventLoopGroupProvider = eventLoopGroupProvider - switch self.eventLoopGroupProvider { + /// - backgroundActivityLogger: The logger to use for background activity logs. + public convenience init( + eventLoopGroupProvider: EventLoopGroupProvider, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger + ) { + let eventLoopGroup: any EventLoopGroup + + switch eventLoopGroupProvider { case .shared(let group): - self.eventLoopGroup = group - case .createNew: - #if canImport(Network) - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { - self.eventLoopGroup = NIOTSEventLoopGroup() - } else { - self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - } - #else - self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - #endif + eventLoopGroup = group + default: // handle `.createNew` without a deprecation warning + eventLoopGroup = HTTPClient.defaultEventLoopGroup } + + self.init( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: backgroundActivityLogger + ) + } + + /// Create an ``HTTPClient`` with specified `EventLoopGroup` and configuration. + /// + /// - parameters: + /// - eventLoopGroup: The `EventLoopGroup` that the ``HTTPClient`` will use. + /// - configuration: Client configuration. + /// - backgroundActivityLogger: The `Logger` that will be used to log background any activity that's not associated with a request. + public convenience init( + eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger + ) { + self.init( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: backgroundActivityLogger, + canBeShutDown: true + ) + } + + internal required init( + eventLoopGroup: EventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger, + canBeShutDown: Bool + ) { + self.canBeShutDown = canBeShutDown + self.eventLoopGroup = eventLoopGroup self.configuration = configuration self.poolManager = HTTPConnectionPool.Manager( eventLoopGroup: self.eventLoopGroup, configuration: self.configuration, backgroundActivityLogger: backgroundActivityLogger ) - self.state = .upAndRunning + self.state = NIOLockedValueBox(.upAndRunning) + self.fileIOThreadPool = NIOLockedValueBox(nil) } deinit { debugOnly { // We want to crash only in debug mode. - switch self.state { - case .shutDown: - break - case .shuttingDown: - preconditionFailure(""" - This state should be totally unreachable. While the HTTPClient is shutting down a \ - reference cycle should exist, that prevents it from deinit. - """) - case .upAndRunning: - preconditionFailure(""" - Client not shut down before the deinit. Please call client.syncShutdown() when no \ - longer needed. Otherwise memory will leak. - """) + self.state.withLockedValue { state in + switch state { + case .shutDown: + break + case .shuttingDown: + preconditionFailure( + """ + This state should be totally unreachable. While the HTTPClient is shutting down a \ + reference cycle should exist, that prevents it from deinit. + """ + ) + case .upAndRunning: + preconditionFailure( + """ + Client not shut down before the deinit. Please call client.shutdown() when no \ + longer needed. Otherwise memory will leak. + """ + ) + } } } } /// Shuts down the client and `EventLoopGroup` if it was created by the client. + /// + /// This method blocks the thread indefinitely, prefer using ``shutdown()-96ayw``. + @available(*, noasync, message: "syncShutdown() can block indefinitely, prefer shutdown()", renamed: "shutdown()") public func syncShutdown() throws { try self.syncShutdown(requiresCleanClose: false) } @@ -152,68 +228,101 @@ public class HTTPClient { /// throw the appropriate error if needed. For instance, if its internal connection pool has any non-released connections, /// this indicate shutdown was called too early before tasks were completed or explicitly canceled. /// In general, setting this parameter to `true` should make it easier and faster to catch related programming errors. - internal func syncShutdown(requiresCleanClose: Bool) throws { + func syncShutdown(requiresCleanClose: Bool) throws { if let eventLoop = MultiThreadedEventLoopGroup.currentEventLoop { - preconditionFailure(""" - BUG DETECTED: syncShutdown() must not be called when on an EventLoop. - Calling syncShutdown() on any EventLoop can lead to deadlocks. - Current eventLoop: \(eventLoop) - """) + preconditionFailure( + """ + BUG DETECTED: syncShutdown() must not be called when on an EventLoop. + Calling syncShutdown() on any EventLoop can lead to deadlocks. + Current eventLoop: \(eventLoop) + """ + ) } - let errorStorageLock = Lock() - var errorStorage: Error? - let continuation = DispatchWorkItem {} - self.shutdown(requiresCleanClose: requiresCleanClose, queue: DispatchQueue(label: "async-http-client.shutdown")) { error in - if let error = error { - errorStorageLock.withLock { - errorStorage = error + + final class ShutdownError: @unchecked Sendable { + // @unchecked because error is protected by lock. + + // Stores whether the shutdown has happened or not. + private let lock: ConditionLock + private var error: Error? + + init() { + self.error = nil + self.lock = ConditionLock(value: false) + } + + func didShutdown(_ error: (any Error)?) { + self.lock.lock(whenValue: false) + defer { + self.lock.unlock(withValue: true) } + self.error = error } - continuation.perform() - } - continuation.wait() - try errorStorageLock.withLock { - if let error = errorStorage { - throw error + + func blockUntilShutdown() -> (any Error)? { + self.lock.lock(whenValue: true) + defer { + self.lock.unlock(withValue: true) + } + return self.error } } + + let shutdownError = ShutdownError() + + self.shutdown( + requiresCleanClose: requiresCleanClose, + queue: DispatchQueue(label: "async-http-client.shutdown") + ) { error in + shutdownError.didShutdown(error) + } + + let error = shutdownError.blockUntilShutdown() + + if let error = error { + throw error + } } - /// Shuts down the client and event loop gracefully. This function is clearly an outlier in that it uses a completion + /// Shuts down the client and event loop gracefully. + /// + /// This function is clearly an outlier in that it uses a completion /// callback instead of an EventLoopFuture. The reason for that is that NIO's EventLoopFutures will call back on an event loop. /// The virtue of this function is to shut the event loop down. To work around that we call back on a DispatchQueue /// instead. - public func shutdown(queue: DispatchQueue = .global(), _ callback: @escaping (Error?) -> Void) { + @preconcurrency public func shutdown( + queue: DispatchQueue = .global(), + _ callback: @Sendable @escaping (Error?) -> Void + ) { self.shutdown(requiresCleanClose: false, queue: queue, callback) } - private func shutdownEventLoop(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { - self.stateLock.withLock { - switch self.eventLoopGroupProvider { - case .shared: - self.state = .shutDown - queue.async { - callback(nil) - } - case .createNew: - switch self.state { - case .shuttingDown: - self.state = .shutDown - self.eventLoopGroup.shutdownGracefully(queue: queue, callback) - case .shutDown, .upAndRunning: - assertionFailure("The only valid state at this point is \(String(describing: State.shuttingDown))") - } + /// Shuts down the ``HTTPClient`` and releases its resources. + public func shutdown() -> EventLoopFuture { + let promise = self.eventLoopGroup.any().makePromise(of: Void.self) + self.shutdown(queue: .global()) { error in + if let error = error { + promise.fail(error) + } else { + promise.succeed(()) } } + return promise.futureResult } - private func shutdown(requiresCleanClose: Bool, queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { + private func shutdown(requiresCleanClose: Bool, queue: DispatchQueue, _ callback: @escaping ShutdownCallback) { + guard self.canBeShutDown else { + queue.async { + callback(HTTPClientError.shutdownUnsupported) + } + return + } do { - try self.stateLock.withLock { - guard case .upAndRunning = self.state else { + try self.state.withLockedValue { state in + guard case .upAndRunning = state else { throw HTTPClientError.alreadyShutdown } - self.state = .shuttingDown(requiresCleanClose: requiresCleanClose, callback: callback) + state = .shuttingDown(requiresCleanClose: requiresCleanClose, callback: callback) } } catch { callback(error) @@ -227,30 +336,40 @@ public class HTTPClient { case .failure: preconditionFailure("Shutting down the connection pool must not fail, ever.") case .success(let unclean): - let (callback, uncleanError) = self.stateLock.withLock { () -> ((Error?) -> Void, Error?) in - guard case .shuttingDown(let requiresClean, callback: let callback) = self.state else { + let (callback, uncleanError) = self.state.withLockedValue { + (state: inout HTTPClient.State) -> (ShutdownCallback, Error?) in + guard case .shuttingDown(let requiresClean, callback: let callback) = state else { preconditionFailure("Why did the pool manager shut down, if it was not instructed to") } let error: Error? = (requiresClean && unclean) ? HTTPClientError.uncleanShutdown : nil + state = .shutDown return (callback, error) } - - self.shutdownEventLoop(queue: queue) { error in - let reportedError = error ?? uncleanError - callback(reportedError) + queue.async { + callback(uncleanError) } } } } + @Sendable + private func makeOrGetFileIOThreadPool() -> NIOThreadPool { + self.fileIOThreadPool.withLockedValue { pool in + guard let pool else { + return NIOThreadPool.singleton + } + return pool + } + } + /// Execute `GET` request using specified URL. /// /// - parameters: /// - url: Remote URL. /// - deadline: Point in time by which the request must complete. public func get(url: String, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.get(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.get(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `GET` request using specified URL. @@ -260,7 +379,7 @@ public class HTTPClient { /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. public func get(url: String, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.GET, url: url, deadline: deadline, logger: logger) + self.execute(.GET, url: url, deadline: deadline, logger: logger) } /// Execute `POST` request using specified URL. @@ -270,7 +389,7 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. public func post(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.post(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.post(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `POST` request using specified URL. @@ -280,8 +399,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func post(url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.POST, url: url, body: body, deadline: deadline, logger: logger) + public func post( + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger + ) -> EventLoopFuture { + self.execute(.POST, url: url, body: body, deadline: deadline, logger: logger) } /// Execute `PATCH` request using specified URL. @@ -291,7 +415,7 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. public func patch(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.patch(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.patch(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `PATCH` request using specified URL. @@ -301,8 +425,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func patch(url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.PATCH, url: url, body: body, deadline: deadline, logger: logger) + public func patch( + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger + ) -> EventLoopFuture { + self.execute(.PATCH, url: url, body: body, deadline: deadline, logger: logger) } /// Execute `PUT` request using specified URL. @@ -312,7 +441,7 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. public func put(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.put(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.put(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `PUT` request using specified URL. @@ -322,8 +451,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func put(url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.PUT, url: url, body: body, deadline: deadline, logger: logger) + public func put( + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger + ) -> EventLoopFuture { + self.execute(.PUT, url: url, body: body, deadline: deadline, logger: logger) } /// Execute `DELETE` request using specified URL. @@ -332,7 +466,7 @@ public class HTTPClient { /// - url: Remote URL. /// - deadline: The time when the request must have been completed by. public func delete(url: String, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.delete(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.delete(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `DELETE` request using specified URL. @@ -342,7 +476,7 @@ public class HTTPClient { /// - deadline: The time when the request must have been completed by. /// - logger: The logger to use for this request. public func delete(url: String, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.DELETE, url: url, deadline: deadline, logger: logger) + self.execute(.DELETE, url: url, deadline: deadline, logger: logger) } /// Execute arbitrary HTTP request using specified URL. @@ -353,7 +487,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(_ method: HTTPMethod = .GET, url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger? = nil) -> EventLoopFuture { + public func execute( + _ method: HTTPMethod = .GET, + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger? = nil + ) -> EventLoopFuture { do { let request = try Request(url: url, method: method, body: body) return self.execute(request: request, deadline: deadline, logger: logger ?? HTTPClient.loggingDisabled) @@ -371,7 +511,14 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(_ method: HTTPMethod = .GET, socketPath: String, urlPath: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger? = nil) -> EventLoopFuture { + public func execute( + _ method: HTTPMethod = .GET, + socketPath: String, + urlPath: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger? = nil + ) -> EventLoopFuture { do { guard let url = URL(httpURLWithSocketPath: socketPath, uri: urlPath) else { throw HTTPClientError.invalidURL @@ -392,7 +539,14 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(_ method: HTTPMethod = .GET, secureSocketPath: String, urlPath: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger? = nil) -> EventLoopFuture { + public func execute( + _ method: HTTPMethod = .GET, + secureSocketPath: String, + urlPath: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger? = nil + ) -> EventLoopFuture { do { guard let url = URL(httpsURLWithSocketPath: secureSocketPath, uri: urlPath) else { throw HTTPClientError.invalidURL @@ -410,7 +564,7 @@ public class HTTPClient { /// - request: HTTP request to execute. /// - deadline: Point in time by which the request must complete. public func execute(request: Request, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.execute(request: request, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.execute(request: request, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute arbitrary HTTP request using specified URL. @@ -430,26 +584,40 @@ public class HTTPClient { /// - request: HTTP request to execute. /// - eventLoop: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. - public func execute(request: Request, eventLoop: EventLoopPreference, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.execute(request: request, - eventLoop: eventLoop, - deadline: deadline, - logger: HTTPClient.loggingDisabled) + public func execute( + request: Request, + eventLoop: EventLoopPreference, + deadline: NIODeadline? = nil + ) -> EventLoopFuture { + self.execute( + request: request, + eventLoop: eventLoop, + deadline: deadline, + logger: HTTPClient.loggingDisabled + ) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. /// /// - parameters: /// - request: HTTP request to execute. - /// - eventLoop: NIO Event Loop preference. + /// - eventLoopPreference: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(request: Request, - eventLoop eventLoopPreference: EventLoopPreference, - deadline: NIODeadline? = nil, - logger: Logger?) -> EventLoopFuture { + public func execute( + request: Request, + eventLoop eventLoopPreference: EventLoopPreference, + deadline: NIODeadline? = nil, + logger: Logger? + ) -> EventLoopFuture { let accumulator = ResponseAccumulator(request: request) - return self.execute(request: request, delegate: accumulator, eventLoop: eventLoopPreference, deadline: deadline, logger: logger).futureResult + return self.execute( + request: request, + delegate: accumulator, + eventLoop: eventLoopPreference, + deadline: deadline, + logger: logger + ).futureResult } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -458,10 +626,12 @@ public class HTTPClient { /// - request: HTTP request to execute. /// - delegate: Delegate to process response parts. /// - deadline: Point in time by which the request must complete. - public func execute(request: Request, - delegate: Delegate, - deadline: NIODeadline? = nil) -> Task { - return self.execute(request: request, delegate: delegate, deadline: deadline, logger: HTTPClient.loggingDisabled) + public func execute( + request: Request, + delegate: Delegate, + deadline: NIODeadline? = nil + ) -> Task { + self.execute(request: request, delegate: delegate, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -471,11 +641,13 @@ public class HTTPClient { /// - delegate: Delegate to process response parts. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(request: Request, - delegate: Delegate, - deadline: NIODeadline? = nil, - logger: Logger) -> Task { - return self.execute(request: request, delegate: delegate, eventLoop: .indifferent, deadline: deadline, logger: logger) + public func execute( + request: Request, + delegate: Delegate, + deadline: NIODeadline? = nil, + logger: Logger + ) -> Task { + self.execute(request: request, delegate: delegate, eventLoop: .indifferent, deadline: deadline, logger: logger) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -483,18 +655,21 @@ public class HTTPClient { /// - parameters: /// - request: HTTP request to execute. /// - delegate: Delegate to process response parts. - /// - eventLoop: NIO Event Loop preference. + /// - eventLoopPreference: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. - /// - logger: The logger to use for this request. - public func execute(request: Request, - delegate: Delegate, - eventLoop eventLoopPreference: EventLoopPreference, - deadline: NIODeadline? = nil) -> Task { - return self.execute(request: request, - delegate: delegate, - eventLoop: eventLoopPreference, - deadline: deadline, - logger: HTTPClient.loggingDisabled) + public func execute( + request: Request, + delegate: Delegate, + eventLoop eventLoopPreference: EventLoopPreference, + deadline: NIODeadline? = nil + ) -> Task { + self.execute( + request: request, + delegate: delegate, + eventLoop: eventLoopPreference, + deadline: deadline, + logger: HTTPClient.loggingDisabled + ) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -502,7 +677,7 @@ public class HTTPClient { /// - parameters: /// - request: HTTP request to execute. /// - delegate: Delegate to process response parts. - /// - eventLoop: NIO Event Loop preference. + /// - eventLoopPreference: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. public func execute( @@ -510,14 +685,14 @@ public class HTTPClient { delegate: Delegate, eventLoop eventLoopPreference: EventLoopPreference, deadline: NIODeadline? = nil, - logger originalLogger: Logger? + logger: Logger? ) -> Task { self._execute( request: request, delegate: delegate, eventLoop: eventLoopPreference, deadline: deadline, - logger: originalLogger, + logger: logger, redirectState: RedirectState( self.configuration.redirectConfiguration.mode, initialURL: request.url.absoluteString @@ -541,34 +716,53 @@ public class HTTPClient { logger originalLogger: Logger?, redirectState: RedirectState? ) -> Task { - let logger = (originalLogger ?? HTTPClient.loggingDisabled).attachingRequestInformation(request, requestID: globalRequestID.add(1)) + let logger = (originalLogger ?? HTTPClient.loggingDisabled).attachingRequestInformation( + request, + requestID: globalRequestID.wrappingIncrementThenLoad(ordering: .relaxed) + ) + let taskEL: EventLoop switch eventLoopPreference.preference { case .indifferent: // if possible we want a connection on the current `EventLoop` taskEL = self.eventLoopGroup.any() case .delegate(on: let eventLoop): - precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "Provided EventLoop must be part of clients EventLoopGroup.") + precondition( + self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, + "Provided EventLoop must be part of clients EventLoopGroup." + ) taskEL = eventLoop case .delegateAndChannel(on: let eventLoop): - precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "Provided EventLoop must be part of clients EventLoopGroup.") + precondition( + self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, + "Provided EventLoop must be part of clients EventLoopGroup." + ) taskEL = eventLoop case .testOnly_exact(_, delegateOn: let delegateEL): taskEL = delegateEL } - logger.trace("selected EventLoop for task given the preference", - metadata: ["ahc-eventloop": "\(taskEL)", - "ahc-el-preference": "\(eventLoopPreference)"]) - let failedTask: Task? = self.stateLock.withLock { + logger.trace( + "selected EventLoop for task given the preference", + metadata: [ + "ahc-eventloop": "\(taskEL)", + "ahc-el-preference": "\(eventLoopPreference)", + ] + ) + + let failedTask: Task? = self.state.withLockedValue { state -> (Task?) in switch state { case .upAndRunning: return nil case .shuttingDown, .shutDown: logger.debug("client is shutting down, failing request") - return Task.failedTask(eventLoop: taskEL, - error: HTTPClientError.alreadyShutdown, - logger: logger) + return Task.failedTask( + eventLoop: taskEL, + error: HTTPClientError.alreadyShutdown, + logger: logger, + tracing: tracing, + makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool + ) } } @@ -591,42 +785,60 @@ public class HTTPClient { } }() - let task = Task(eventLoop: taskEL, logger: logger) + let task: HTTPClient.Task = + Task( + eventLoop: taskEL, + logger: logger, + tracing: self.tracing, + makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool + ) + do { let requestBag = try RequestBag( request: request, eventLoopPreference: eventLoopPreference, task: task, redirectHandler: redirectHandler, - connectionDeadline: .now() + (self.configuration.timeout.connect ?? .seconds(10)), + connectionDeadline: .now() + (self.configuration.timeout.connectionCreationTimeout), requestOptions: .fromClientConfiguration(self.configuration), delegate: delegate ) - var deadlineSchedule: Scheduled? if let deadline = deadline { - deadlineSchedule = taskEL.scheduleTask(deadline: deadline) { - requestBag.fail(HTTPClientError.deadlineExceeded) + let deadlineSchedule = taskEL.scheduleTask(deadline: deadline) { + requestBag.deadlineExceeded() } task.promise.futureResult.whenComplete { _ in - deadlineSchedule?.cancel() + deadlineSchedule.cancel() } } self.poolManager.executeRequest(requestBag) } catch { - task.fail(with: error, delegateType: Delegate.self) + delegate.didReceiveError(task: task, error) + task.failInternal(with: error) } return task } - /// `HTTPClient` configuration. + /// ``HTTPClient`` configuration. public struct Configuration { /// TLS configuration, defaults to `TLSConfiguration.makeClientConfiguration()`. public var tlsConfiguration: Optional - /// Enables following 3xx redirects automatically, defaults to `RedirectConfiguration()`. + + /// Sometimes it can be useful to connect to one host e.g. `x.example.com` but + /// request and validate the certificate chain as if we would connect to `y.example.com`. + /// ``dnsOverride`` allows to do just that by mapping host names which we will request and validate the certificate chain, to a different + /// host name which will be used to actually connect to. + /// + /// **Example:** if ``dnsOverride`` is set to `["example.com": "localhost"]` and we execute a request with a + /// `url` of `https://example.com/`, the ``HTTPClient`` will actually open a connection to `localhost` instead of `example.com`. + /// ``HTTPClient`` will still request certificates from the server for `example.com` and validate them as if we would connect to `example.com`. + public var dnsOverride: [String: String] = [:] + + /// Enables following 3xx redirects automatically. /// /// Following redirects are supported: /// - `301: Moved Permanently` @@ -637,7 +849,8 @@ public class HTTPClient { /// - `307: Temporary Redirect` /// - `308: Permanent Redirect` public var redirectConfiguration: RedirectConfiguration - /// Default client timeout, defaults to no `read` timeout and 10 seconds `connect` timeout. + /// Default client timeout, defaults to no ``Timeout-swift.struct/read`` timeout + /// and 10 seconds ``Timeout-swift.struct/connect`` timeout. public var timeout: Timeout /// Connection pool configuration. public var connectionPool: ConnectionPool @@ -646,15 +859,54 @@ public class HTTPClient { /// Enables automatic body decompression. Supported algorithms are gzip and deflate. public var decompression: Decompression /// Ignore TLS unclean shutdown error, defaults to `false`. - @available(*, deprecated, message: "AsyncHTTPClient now correctly supports handling unexpected SSL connection drops. This property is ignored") + @available( + *, + deprecated, + message: + "AsyncHTTPClient now correctly supports handling unexpected SSL connection drops. This property is ignored" + ) public var ignoreUncleanSSLShutdown: Bool { get { false } set {} } - /// is set to `.automatic` by default which will use HTTP/2 if run over https and the server supports it, otherwise HTTP/1 + /// What HTTP versions to use. + /// + /// Set to ``HTTPVersion-swift.struct/automatic`` by default which will use HTTP/2 if run over https and the server supports it, otherwise HTTP/1 public var httpVersion: HTTPVersion + /// Whether ``HTTPClient`` will let Network.framework sit in the `.waiting` state awaiting new network changes, or fail immediately. Defaults to `true`, + /// which is the recommended setting. Only set this to `false` when attempting to trigger a particular error path. + public var networkFrameworkWaitForConnectivity: Bool + + /// The maximum number of times each connection can be used before it is replaced with a new one. Use `nil` (the default) + /// if no limit should be applied to each connection. + /// + /// - Precondition: The value must be greater than zero. + public var maximumUsesPerConnection: Int? { + willSet { + if let newValue = newValue, newValue <= 0 { + fatalError("maximumUsesPerConnection must be greater than zero or nil") + } + } + } + + /// Whether ``HTTPClient`` will use Multipath TCP or not + /// By default, don't use it + public var enableMultipath: Bool + + /// A method with access to the HTTP/1 connection channel that is called when creating the connection. + public var http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + + /// A method with access to the HTTP/2 connection channel that is called when creating the connection. + public var http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + + /// A method with access to the HTTP/2 stream channel that is called when creating the stream. + public var http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + + /// Configuration how distributed traces are created and handled. + public var tracing: TracingConfiguration = .init() + public init( tlsConfiguration: TLSConfiguration? = nil, redirectConfiguration: RedirectConfiguration? = nil, @@ -671,14 +923,18 @@ public class HTTPClient { self.proxy = proxy self.decompression = decompression self.httpVersion = .automatic + self.networkFrameworkWaitForConnectivity = true + self.enableMultipath = false } - public init(tlsConfiguration: TLSConfiguration? = nil, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled) { + public init( + tlsConfiguration: TLSConfiguration? = nil, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled + ) { self.init( tlsConfiguration: tlsConfiguration, redirectConfiguration: redirectConfiguration, @@ -690,49 +946,59 @@ public class HTTPClient { ) } - public init(certificateVerification: CertificateVerification, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - maximumAllowedIdleTimeInConnectionPool: TimeAmount = .seconds(60), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled) { + public init( + certificateVerification: CertificateVerification, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + maximumAllowedIdleTimeInConnectionPool: TimeAmount = .seconds(60), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled + ) { var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.certificateVerification = certificateVerification - self.init(tlsConfiguration: tlsConfig, - redirectConfiguration: redirectConfiguration, - timeout: timeout, - connectionPool: ConnectionPool(), - proxy: proxy, - ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, - decompression: decompression) + self.init( + tlsConfiguration: tlsConfig, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: ConnectionPool(idleTimeout: maximumAllowedIdleTimeInConnectionPool), + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) } - public init(certificateVerification: CertificateVerification, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - connectionPool: TimeAmount = .seconds(60), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled, - backgroundActivityLogger: Logger?) { + public init( + certificateVerification: CertificateVerification, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + connectionPool: TimeAmount = .seconds(60), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled, + backgroundActivityLogger: Logger? + ) { var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.certificateVerification = certificateVerification - self.init(tlsConfiguration: tlsConfig, - redirectConfiguration: redirectConfiguration, - timeout: timeout, - connectionPool: ConnectionPool(), - proxy: proxy, - ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, - decompression: decompression) + self.init( + tlsConfiguration: tlsConfig, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: ConnectionPool(idleTimeout: connectionPool), + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) } - public init(certificateVerification: CertificateVerification, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled) { + public init( + certificateVerification: CertificateVerification, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled + ) { self.init( certificateVerification: certificateVerification, redirectConfiguration: redirectConfiguration, @@ -743,17 +1009,130 @@ public class HTTPClient { decompression: decompression ) } + + public init( + tlsConfiguration: TLSConfiguration? = nil, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + connectionPool: ConnectionPool = ConnectionPool(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled, + http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil + ) { + self.init( + tlsConfiguration: tlsConfiguration, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: connectionPool, + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) + self.http1_1ConnectionDebugInitializer = http1_1ConnectionDebugInitializer + self.http2ConnectionDebugInitializer = http2ConnectionDebugInitializer + self.http2StreamChannelDebugInitializer = http2StreamChannelDebugInitializer + } + + public init( + tlsConfiguration: TLSConfiguration? = nil, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + connectionPool: ConnectionPool = ConnectionPool(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled, + http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + tracing: TracingConfiguration = .init() + ) { + self.init( + tlsConfiguration: tlsConfiguration, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: connectionPool, + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) + self.http1_1ConnectionDebugInitializer = http1_1ConnectionDebugInitializer + self.http2ConnectionDebugInitializer = http2ConnectionDebugInitializer + self.http2StreamChannelDebugInitializer = http2StreamChannelDebugInitializer + self.tracing = tracing + } + } + + public struct TracingConfiguration: Sendable { + + @usableFromInline + var _tracer: Optional // erasure trick so we don't have to make Configuration @available + + /// Tracer that should be used by the HTTPClient. + /// + /// This is selected at configuration creation time, and if no tracer is passed explicitly, + /// (including `nil` in order to disable traces), the default global bootstrapped tracer will + /// be stored in this property, and used for all subsequent requests made by this client. + @inlinable + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public var tracer: (any Tracer)? { + get { + guard let _tracer else { + return nil + } + return _tracer as! (any Tracer)? + } + set { + self._tracer = newValue + } + } + + // TODO: Open up customization of keys we use? + /// Configuration for tracing attributes set by the HTTPClient. + @usableFromInline + package var attributeKeys: AttributeKeys + + public init() { + if #available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) { + self._tracer = InstrumentationSystem.tracer + } else { + self._tracer = nil + } + self.attributeKeys = .init() + } + + /// Span attribute keys that the HTTPClient should set automatically. + /// This struct allows the configuration of the attribute names (keys) which will be used for the apropriate values. + @usableFromInline + package struct AttributeKeys: Sendable { + @usableFromInline package var requestMethod: String = "http.request.method" + @usableFromInline package var requestBodySize: String = "http.request.body.size" + + @usableFromInline package var responseBodySize: String = "http.response.body.size" + @usableFromInline package var responseStatusCode: String = "http.status_code" + + @usableFromInline package var httpFlavor: String = "http.flavor" + + @usableFromInline package init() {} + } } /// Specifies how `EventLoopGroup` will be created and establishes lifecycle ownership. public enum EventLoopGroupProvider { /// `EventLoopGroup` will be provided by the user. Owner of this group is responsible for its lifecycle. case shared(EventLoopGroup) - /// `EventLoopGroup` will be created by the client. When `syncShutdown` is called, created `EventLoopGroup` will be shut down as well. + /// The original intention of this was that ``HTTPClient`` would create and own its own `EventLoopGroup` to + /// facilitate use in programs that are not already using SwiftNIO. + /// Since https://github.com/apple/swift-nio/pull/2471 however, SwiftNIO does provide a global, shared singleton + /// `EventLoopGroup`s that we can use. ``HTTPClient`` is no longer able to create & own its own + /// `EventLoopGroup` which solves a whole host of issues around shutdown. + @available(*, deprecated, renamed: "singleton", message: "Please use the singleton EventLoopGroup explicitly") case createNew } - /// Specifies how the library will treat event loop passed by the user. + /// Specifies how the library will treat the event loop passed by the user. public struct EventLoopPreference { enum Preference { /// Event Loop will be selected by the library. @@ -781,7 +1160,7 @@ public class HTTPClient { /// `EventLoop` but will not establish a new network connection just to satisfy the `EventLoop` preference if /// another existing connection on a different `EventLoop` is readily available from a connection pool. public static func delegate(on eventLoop: EventLoop) -> EventLoopPreference { - return EventLoopPreference(.delegate(on: eventLoop)) + EventLoopPreference(.delegate(on: eventLoop)) } /// The delegate and the `Channel` will be run on the specified EventLoop. @@ -789,34 +1168,67 @@ public class HTTPClient { /// Use this for use-cases where you prefer a new connection to be established over re-using an existing /// connection that might be on a different `EventLoop`. public static func delegateAndChannel(on eventLoop: EventLoop) -> EventLoopPreference { - return EventLoopPreference(.delegateAndChannel(on: eventLoop)) + EventLoopPreference(.delegateAndChannel(on: eventLoop)) } } /// Specifies decompression settings. - public enum Decompression { + public enum Decompression: Sendable { /// Decompression is disabled. case disabled /// Decompression is enabled. case enabled(limit: NIOHTTPDecompression.DecompressionLimit) } + typealias ShutdownCallback = @Sendable (Error?) -> Void + enum State { case upAndRunning - case shuttingDown(requiresCleanClose: Bool, callback: (Error?) -> Void) + case shuttingDown(requiresCleanClose: Bool, callback: ShutdownCallback) case shutDown } } +extension HTTPClient.EventLoopGroupProvider { + /// Shares ``HTTPClient/defaultEventLoopGroup`` which is a singleton `EventLoopGroup` suitable for the platform. + public static var singleton: Self { + .shared(HTTPClient.defaultEventLoopGroup) + } +} + +extension HTTPClient { + /// Returns the default `EventLoopGroup` singleton, automatically selecting the best for the platform. + /// + /// This will select the concrete `EventLoopGroup` depending which platform this is running on. + public static var defaultEventLoopGroup: EventLoopGroup { + #if canImport(Network) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) { + return NIOTSEventLoopGroup.singleton + } else { + return MultiThreadedEventLoopGroup.singleton + } + #else + return MultiThreadedEventLoopGroup.singleton + #endif + } +} + +extension HTTPClient.Configuration: Sendable {} + +extension HTTPClient.EventLoopGroupProvider: Sendable {} +extension HTTPClient.EventLoopPreference: Sendable {} + extension HTTPClient.Configuration { /// Timeout configuration. - public struct Timeout { - /// Specifies connect timeout. If no connect timeout is given, a default 30 seconds timeout will applied. + public struct Timeout: Sendable { + /// Specifies connect timeout. If no connect timeout is given, a default 10 seconds timeout will be applied. public var connect: TimeAmount? /// Specifies read timeout. public var read: TimeAmount? + /// Specifies the maximum amount of time without bytes being written by the client before closing the connection. + public var write: TimeAmount? - /// internal connection creation timeout. Defaults the connect timeout to always contain a value. + /// Internal connection creation timeout. Defaults the connect timeout to always contain a value. var connectionCreationTimeout: TimeAmount { self.connect ?? .seconds(10) } @@ -824,17 +1236,35 @@ extension HTTPClient.Configuration { /// Create timeout. /// /// - parameters: - /// - connect: `connect` timeout. Will default to 10 seconds, if no value is - /// provided. See `var connectionCreationTimeout` + /// - connect: `connect` timeout. Will default to 10 seconds, if no value is provided. /// - read: `read` timeout. - public init(connect: TimeAmount? = nil, read: TimeAmount? = nil) { + public init( + connect: TimeAmount? = nil, + read: TimeAmount? = nil + ) { self.connect = connect self.read = read } + + /// Create timeout. + /// + /// - parameters: + /// - connect: `connect` timeout. Will default to 10 seconds, if no value is provided. + /// - read: `read` timeout. + /// - write: `write` timeout. + public init( + connect: TimeAmount? = nil, + read: TimeAmount? = nil, + write: TimeAmount + ) { + self.connect = connect + self.read = read + self.write = write + } } /// Specifies redirect processing settings. - public struct RedirectConfiguration { + public struct RedirectConfiguration: Sendable { enum Mode { /// Redirects are not followed. case disallow @@ -862,21 +1292,45 @@ extension HTTPClient.Configuration { /// - allowCycles: Whether cycles are allowed. /// /// - warning: Cycle detection will keep all visited URLs in memory which means a malicious server could use this as a denial-of-service vector. - public static func follow(max: Int, allowCycles: Bool) -> RedirectConfiguration { return .init(configuration: .follow(max: max, allowCycles: allowCycles)) } + public static func follow(max: Int, allowCycles: Bool) -> RedirectConfiguration { + .init(configuration: .follow(max: max, allowCycles: allowCycles)) + } } /// Connection pool configuration. - public struct ConnectionPool: Hashable { + public struct ConnectionPool: Hashable, Sendable { /// Specifies amount of time connections are kept idle in the pool. After this time has passed without a new /// request the connections are closed. - public var idleTimeout: TimeAmount + public var idleTimeout: TimeAmount = .seconds(60) /// The maximum number of connections that are kept alive in the connection pool per host. If requests with /// an explicit eventLoopRequirement are sent, this number might be exceeded due to overflow connections. - public var concurrentHTTP1ConnectionsPerHostSoftLimit: Int + public var concurrentHTTP1ConnectionsPerHostSoftLimit: Int = 8 + + /// If true, ``HTTPClient`` will try to create new connections on connection failure with an exponential backoff. + /// Requests will only fail after the ``HTTPClient/Configuration/Timeout-swift.struct/connect`` timeout exceeded. + /// If false, all requests that have no assigned connection will fail immediately after a connection could not be established. + /// Defaults to `true`. + /// - warning: We highly recommend leaving this on. + /// It is very common that connections establishment is flaky at scale. + /// ``HTTPClient`` will automatically mitigate these kind of issues if this flag is turned on. + public var retryConnectionEstablishment: Bool = true + + /// The number of pre-warmed HTTP/1 connections to maintain. + /// + /// When set to a number greater than zero, any HTTP/1 connection pool created will attempt to maintain + /// at least this number of "extra" idle connections, above the connections currently in use, up to the + /// limit of ``concurrentHTTP1ConnectionsPerHostSoftLimit``. + /// + /// These connections will not be made while the pool is idle: only once the first connection is made + /// to a host will the others be opened. In addition, to manage the connection creation rate and + /// avoid flooding servers, prewarmed connection creation will be done one-at-a-time. + public var preWarmedHTTP1ConnectionCount: Int = 0 - public init(idleTimeout: TimeAmount = .seconds(60)) { - self.init(idleTimeout: idleTimeout, concurrentHTTP1ConnectionsPerHostSoftLimit: 8) + public init() {} + + public init(idleTimeout: TimeAmount) { + self.idleTimeout = idleTimeout } public init(idleTimeout: TimeAmount, concurrentHTTP1ConnectionsPerHostSoftLimit: Int) { @@ -885,19 +1339,19 @@ extension HTTPClient.Configuration { } } - public struct HTTPVersion { - internal enum Configuration { + public struct HTTPVersion: Sendable, Hashable { + enum Configuration { case http1Only case automatic } - /// we only use HTTP/1, even if the server would supports HTTP/2 + /// We will only use HTTP/1, even if the server would supports HTTP/2 public static let http1Only: Self = .init(configuration: .http1Only) /// HTTP/2 is used if we connect to a server with HTTPS and the server supports HTTP/2, otherwise we use HTTP/1 public static let automatic: Self = .init(configuration: .automatic) - internal var configuration: Configuration + var configuration: Configuration } } @@ -911,6 +1365,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case emptyScheme case unsupportedScheme(String) case readTimeout + case writeTimeout case remoteConnectionClosed case cancelled case identityCodingIncorrectlyPresent @@ -924,6 +1379,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case uncleanShutdown case traceRequestWithBody case invalidHeaderFieldNames([String]) + case invalidHeaderFieldValues([String]) case bodyLengthMismatch case writeAfterRequestSent @available(*, deprecated, message: "AsyncHTTPClient now silently corrects invalid headers.") @@ -937,6 +1393,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case getConnectionFromPoolTimeout case deadlineExceeded case httpEndReceivedAfterHeadWith1xx + case shutdownUnsupported } private var code: Code @@ -946,7 +1403,83 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { } public var description: String { - return "HTTPClientError.\(String(describing: self.code))" + "HTTPClientError.\(String(describing: self.code))" + } + + /// Short description of the error that can be used in case a bounded set of error descriptions is expected, e.g. to + /// include in metric labels. For this reason the description must not contain associated values. + public var shortDescription: String { + // When adding new cases here, do *not* include dynamic (associated) values in the description. + switch self.code { + case .invalidURL: + return "Invalid URL" + case .emptyHost: + return "Empty host" + case .missingSocketPath: + return "Missing socket path" + case .alreadyShutdown: + return "Already shutdown" + case .emptyScheme: + return "Empty scheme" + case .unsupportedScheme: + return "Unsupported scheme" + case .readTimeout: + return "Read timeout" + case .writeTimeout: + return "Write timeout" + case .remoteConnectionClosed: + return "Remote connection closed" + case .cancelled: + return "Cancelled" + case .identityCodingIncorrectlyPresent: + return "Identity coding incorrectly present" + case .chunkedSpecifiedMultipleTimes: + return "Chunked specified multiple times" + case .invalidProxyResponse: + return "Invalid proxy response" + case .contentLengthMissing: + return "Content length missing" + case .proxyAuthenticationRequired: + return "Proxy authentication required" + case .redirectLimitReached: + return "Redirect limit reached" + case .redirectCycleDetected: + return "Redirect cycle detected" + case .uncleanShutdown: + return "Unclean shutdown" + case .traceRequestWithBody: + return "Trace request with body" + case .invalidHeaderFieldNames: + return "Invalid header field names" + case .invalidHeaderFieldValues: + return "Invalid header field values" + case .bodyLengthMismatch: + return "Body length mismatch" + case .writeAfterRequestSent: + return "Write after request sent" + case .incompatibleHeaders: + return "Incompatible headers" + case .connectTimeout: + return "Connect timeout" + case .socksHandshakeTimeout: + return "SOCKS handshake timeout" + case .httpProxyHandshakeTimeout: + return "HTTP proxy handshake timeout" + case .tlsHandshakeTimeout: + return "TLS handshake timeout" + case .serverOfferedUnsupportedApplicationProtocol: + return "Server offered unsupported application protocol" + case .requestStreamCancelled: + return "Request stream cancelled" + case .getConnectionFromPoolTimeout: + return "Get connection from pool timeout" + case .deadlineExceeded: + return "Deadline exceeded" + case .httpEndReceivedAfterHeadWith1xx: + return "HTTP end received after head with 1xx" + case .shutdownUnsupported: + return "The global singleton HTTP client cannot be shut down" + } } /// URL provided is invalid. @@ -960,9 +1493,13 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { /// URL does not contain scheme. public static let emptyScheme = HTTPClientError(code: .emptyScheme) /// Provided URL scheme is not supported, supported schemes are: `http` and `https` - public static func unsupportedScheme(_ scheme: String) -> HTTPClientError { return HTTPClientError(code: .unsupportedScheme(scheme)) } - /// Request timed out. + public static func unsupportedScheme(_ scheme: String) -> HTTPClientError { + HTTPClientError(code: .unsupportedScheme(scheme)) + } + /// Request timed out while waiting for response. public static let readTimeout = HTTPClientError(code: .readTimeout) + /// Request timed out. + public static let writeTimeout = HTTPClientError(code: .writeTimeout) /// Remote connection was closed unexpectedly. public static let remoteConnectionClosed = HTTPClientError(code: .remoteConnectionClosed) /// Request was cancelled. @@ -987,7 +1524,13 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { /// A body was sent in a request with method TRACE. public static let traceRequestWithBody = HTTPClientError(code: .traceRequestWithBody) /// Header field names contain invalid characters. - public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { return HTTPClientError(code: .invalidHeaderFieldNames(names)) } + public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { + HTTPClientError(code: .invalidHeaderFieldNames(names)) + } + /// Header field values contain invalid characters. + public static func invalidHeaderFieldValues(_ values: [String]) -> HTTPClientError { + HTTPClientError(code: .invalidHeaderFieldValues(values)) + } /// Body length is not equal to `Content-Length`. public static let bodyLengthMismatch = HTTPClientError(code: .bodyLengthMismatch) /// Body part was written after request was fully sent. @@ -1005,7 +1548,12 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let tlsHandshakeTimeout = HTTPClientError(code: .tlsHandshakeTimeout) /// The remote server only offered an unsupported application protocol public static func serverOfferedUnsupportedApplicationProtocol(_ proto: String) -> HTTPClientError { - return HTTPClientError(code: .serverOfferedUnsupportedApplicationProtocol(proto)) + HTTPClientError(code: .serverOfferedUnsupportedApplicationProtocol(proto)) + } + + /// The globally shared singleton ``HTTPClient`` cannot be shut down. + public static var shutdownUnsupported: HTTPClientError { + HTTPClientError(code: .shutdownUnsupported) } /// The request deadline was exceeded. The request was cancelled because of this. @@ -1022,6 +1570,11 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { /// - Tasks are not processed fast enough on the existing connections, to process all waiters in time public static let getConnectionFromPoolTimeout = HTTPClientError(code: .getConnectionFromPoolTimeout) - @available(*, deprecated, message: "AsyncHTTPClient now correctly supports informational headers. For this reason `httpEndReceivedAfterHeadWith1xx` will not be thrown anymore.") + @available( + *, + deprecated, + message: + "AsyncHTTPClient now correctly supports informational headers. For this reason `httpEndReceivedAfterHeadWith1xx` will not be thrown anymore." + ) public static let httpEndReceivedAfterHeadWith1xx = HTTPClientError(code: .httpEndReceivedAfterHeadWith1xx) } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index c1ce39632..20df597ca 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -12,25 +12,31 @@ // //===----------------------------------------------------------------------===// +import Algorithms import Foundation import Logging import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 +import NIOPosix import NIOSSL +import Tracing extension HTTPClient { - /// Represent request body. - public struct Body { - /// Chunk provider. - public struct StreamWriter { - let closure: (IOData) -> EventLoopFuture + /// A request body. + public struct Body: Sendable { + /// A streaming uploader. + /// + /// ``StreamWriter`` abstracts + public struct StreamWriter: Sendable { + let closure: @Sendable (IOData) -> EventLoopFuture - /// Create new StreamWriter + /// Create new ``HTTPClient/Body/StreamWriter`` /// /// - parameters: /// - closure: function that will be called to write actual bytes to the channel. - public init(closure: @escaping (IOData) -> EventLoopFuture) { + @preconcurrency + public init(closure: @escaping @Sendable (IOData) -> EventLoopFuture) { self.closure = closure } @@ -39,19 +45,101 @@ extension HTTPClient { /// - parameters: /// - data: `IOData` to write. public func write(_ data: IOData) -> EventLoopFuture { - return self.closure(data) + self.closure(data) + } + + @inlinable + func writeChunks( + of bytes: Bytes, + maxChunkSize: Int + ) -> EventLoopFuture where Bytes.Element == UInt8, Bytes: Sendable { + // `StreamWriter` has design issues, for example + // - https://github.com/swift-server/async-http-client/issues/194 + // - https://github.com/swift-server/async-http-client/issues/264 + // - We're not told the EventLoop the task runs on and the user is free to return whatever EL they + // want. + // One important consideration then is that we must lock around the iterator because we could be hopping + // between threads. + typealias Iterator = EnumeratedSequence>.Iterator + typealias Chunk = (offset: Int, element: ChunksOfCountCollection.Element) + + // HACK (again, we're not told the right EventLoop): Let's write 0 bytes to make the user tell us... + return self.write(.byteBuffer(ByteBuffer())).flatMapWithEventLoop { (_, loop) in + func makeIteratorAndFirstChunk( + bytes: Bytes + ) -> (iterator: Iterator, chunk: Chunk)? { + var iterator = bytes.chunks(ofCount: maxChunkSize).enumerated().makeIterator() + guard let chunk = iterator.next() else { + return nil + } + + return (iterator, chunk) + } + + guard let iteratorAndChunk = makeIteratorAndFirstChunk(bytes: bytes) else { + return loop.makeSucceededVoidFuture() + } + + var iterator = iteratorAndChunk.0 + let chunk = iteratorAndChunk.1 + + // can't use closure here as we recursively call ourselves which closures can't do + func writeNextChunk(_ chunk: Chunk, allDone: EventLoopPromise) { + let loop = allDone.futureResult.eventLoop + loop.assertInEventLoop() + + if let (index, element) = iterator.next() { + self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).hop(to: loop).assumeIsolated().map + { + if (index + 1) % 4 == 0 { + // Let's not stack-overflow if the futures insta-complete which they at least in HTTP/2 + // mode. + // Also, we must frequently return to the EventLoop because we may get the pause signal + // from another thread. If we fail to do that promptly, we may balloon our body chunks + // into memory. + allDone.futureResult.eventLoop.assumeIsolated().execute { + writeNextChunk((offset: index, element: element), allDone: allDone) + } + } else { + writeNextChunk((offset: index, element: element), allDone: allDone) + } + }.nonisolated().cascadeFailure(to: allDone) + } else { + self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).cascade(to: allDone) + } + } + + let allDone = loop.makePromise(of: Void.self) + writeNextChunk(chunk, allDone: allDone) + return allDone.futureResult + } } } - /// Body size. if nil,`Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` + /// Body size. If nil,`Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` /// header is set with the given `length`. - public var length: Int? + @available(*, deprecated, renamed: "contentLength") + public var length: Int? { + get { + self.contentLength.flatMap { Int($0) } + } + set { + self.contentLength = newValue.flatMap { Int64($0) } + } + } + + /// Body size. If nil,`Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` + /// header is set with the given `contentLength`. + public var contentLength: Int64? + /// Body chunk provider. - public var stream: (StreamWriter) -> EventLoopFuture + public var stream: @Sendable (StreamWriter) -> EventLoopFuture + + @usableFromInline typealias StreamCallback = @Sendable (StreamWriter) -> EventLoopFuture @inlinable - init(length: Int?, stream: @escaping (StreamWriter) -> EventLoopFuture) { - self.length = length + init(contentLength: Int64?, stream: @escaping StreamCallback) { + self.contentLength = contentLength.flatMap { $0 } self.stream = stream } @@ -60,29 +148,53 @@ extension HTTPClient { /// - parameters: /// - buffer: Body `ByteBuffer` representation. public static func byteBuffer(_ buffer: ByteBuffer) -> Body { - return Body(length: buffer.readableBytes) { writer in + Body(contentLength: Int64(buffer.readableBytes)) { writer in writer.write(.byteBuffer(buffer)) } } - /// Create and stream body using `StreamWriter`. + /// Create and stream body using ``StreamWriter``. /// /// - parameters: /// - length: Body size. If nil, `Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` /// header is set with the given `length`. /// - stream: Body chunk provider. - public static func stream(length: Int? = nil, _ stream: @escaping (StreamWriter) -> EventLoopFuture) -> Body { - return Body(length: length, stream: stream) + @_disfavoredOverload + @preconcurrency + public static func stream( + length: Int? = nil, + _ stream: @Sendable @escaping (StreamWriter) -> EventLoopFuture + ) -> Body { + Body(contentLength: length.flatMap { Int64($0) }, stream: stream) + } + + /// Create and stream body using ``StreamWriter``. + /// + /// - parameters: + /// - contentLength: Body size. If nil, `Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` + /// header is set with the given `contentLength`. + /// - stream: Body chunk provider. + public static func stream( + contentLength: Int64? = nil, + _ stream: @Sendable @escaping (StreamWriter) -> EventLoopFuture + ) -> Body { + Body(contentLength: contentLength, stream: stream) } /// Create and stream body using a collection of bytes. /// /// - parameters: - /// - data: Body binary representation. + /// - bytes: Body binary representation. + @preconcurrency @inlinable - public static func bytes(_ bytes: Bytes) -> Body where Bytes: RandomAccessCollection, Bytes.Element == UInt8 { - return Body(length: bytes.count) { writer in - writer.write(.byteBuffer(ByteBuffer(bytes: bytes))) + public static func bytes(_ bytes: Bytes) -> Body + where Bytes: RandomAccessCollection, Bytes: Sendable, Bytes.Element == UInt8 { + Body(contentLength: Int64(bytes.count)) { writer in + if bytes.count <= bagOfBytesToByteBufferConversionChunkSize { + return writer.write(.byteBuffer(ByteBuffer(bytes: bytes))) + } else { + return writer.writeChunks(of: bytes, maxChunkSize: bagOfBytesToByteBufferConversionChunkSize) + } } } @@ -91,14 +203,18 @@ extension HTTPClient { /// - parameters: /// - string: Body `String` representation. public static func string(_ string: String) -> Body { - return Body(length: string.utf8.count) { writer in - writer.write(.byteBuffer(ByteBuffer(string: string))) + Body(contentLength: Int64(string.utf8.count)) { writer in + if string.utf8.count <= bagOfBytesToByteBufferConversionChunkSize { + return writer.write(.byteBuffer(ByteBuffer(string: string))) + } else { + return writer.writeChunks(of: string.utf8, maxChunkSize: bagOfBytesToByteBufferConversionChunkSize) + } } } } - /// Represent HTTP request. - public struct Request { + /// Represents an HTTP request. + public struct Request: Sendable { /// Request HTTP method, defaults to `GET`. public let method: HTTPMethod /// Remote URL. @@ -123,7 +239,6 @@ extension HTTPClient { /// /// - parameters: /// - url: Remote `URL`. - /// - version: HTTP version. /// - method: HTTP method. /// - headers: Custom HTTP headers. /// - body: Request body. @@ -132,7 +247,12 @@ extension HTTPClient { /// - `emptyScheme` if URL does not contain HTTP scheme. /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. - public init(url: String, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { + public init( + url: String, + method: HTTPMethod = .GET, + headers: HTTPHeaders = HTTPHeaders(), + body: Body? = nil + ) throws { try self.init(url: url, method: method, headers: headers, body: body, tlsConfiguration: nil) } @@ -140,7 +260,6 @@ extension HTTPClient { /// /// - parameters: /// - url: Remote `URL`. - /// - version: HTTP version. /// - method: HTTP method. /// - headers: Custom HTTP headers. /// - body: Request body. @@ -150,7 +269,13 @@ extension HTTPClient { /// - `emptyScheme` if URL does not contain HTTP scheme. /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. - public init(url: String, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil, tlsConfiguration: TLSConfiguration?) throws { + public init( + url: String, + method: HTTPMethod = .GET, + headers: HTTPHeaders = HTTPHeaders(), + body: Body? = nil, + tlsConfiguration: TLSConfiguration? + ) throws { guard let url = URL(string: url) else { throw HTTPClientError.invalidURL } @@ -170,7 +295,8 @@ extension HTTPClient { /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. /// - `missingSocketPath` if URL does not contains a socketPath as an encoded host. - public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { + public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws + { try self.init(url: url, method: method, headers: headers, body: body, tlsConfiguration: nil) } @@ -187,7 +313,13 @@ extension HTTPClient { /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. /// - `missingSocketPath` if URL does not contains a socketPath as an encoded host. - public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil, tlsConfiguration: TLSConfiguration?) throws { + public init( + url: URL, + method: HTTPMethod = .GET, + headers: HTTPHeaders = HTTPHeaders(), + body: Body? = nil, + tlsConfiguration: TLSConfiguration? + ) throws { self.deconstructedURL = try DeconstructedURL(url: url) self.url = url @@ -220,14 +352,26 @@ extension HTTPClient { head.headers.addHostIfNeeded(for: self.deconstructedURL) - let metadata = try head.headers.validateAndSetTransportFraming(method: self.method, bodyLength: .init(self.body)) + let metadata = try head.headers.validateAndSetTransportFraming( + method: self.method, + bodyLength: .init(self.body) + ) return (head, metadata) } + + /// Set basic auth for a request. + /// + /// - parameters: + /// - username: the username to authenticate with + /// - password: authentication password associated with the username + public mutating func setBasicAuth(username: String, password: String) { + self.headers.setBasicAuth(username: username, password: password) + } } - /// Represent HTTP response. - public struct Response { + /// Represents an HTTP response. + public struct Response: Sendable { /// Remote host of the request. public var host: String /// Response HTTP status. @@ -238,6 +382,13 @@ extension HTTPClient { public var headers: HTTPHeaders /// Response body. public var body: ByteBuffer? + /// The history of all requests and responses in redirect order. + public var history: [RequestResponse] + + /// The target URL (after redirects) of the response. + public var url: URL? { + self.history.last?.request.url + } /// Create HTTP `Response`. /// @@ -253,6 +404,7 @@ extension HTTPClient { self.version = HTTPVersion(major: 1, minor: 1) self.headers = headers self.body = body + self.history = [] } /// Create HTTP `Response`. @@ -263,17 +415,49 @@ extension HTTPClient { /// - version: Response HTTP version. /// - headers: Reponse HTTP headers. /// - body: Response body. - public init(host: String, status: HTTPResponseStatus, version: HTTPVersion, headers: HTTPHeaders, body: ByteBuffer?) { + public init( + host: String, + status: HTTPResponseStatus, + version: HTTPVersion, + headers: HTTPHeaders, + body: ByteBuffer? + ) { self.host = host self.status = status self.version = version self.headers = headers self.body = body + self.history = [] + } + + /// Create HTTP `Response`. + /// + /// - parameters: + /// - host: Remote host of the request. + /// - status: Response HTTP status. + /// - version: Response HTTP version. + /// - headers: Reponse HTTP headers. + /// - body: Response body. + /// - history: History of all requests and responses in redirect order. + public init( + host: String, + status: HTTPResponseStatus, + version: HTTPVersion, + headers: HTTPHeaders, + body: ByteBuffer?, + history: [RequestResponse] + ) { + self.host = host + self.status = status + self.version = version + self.headers = headers + self.body = body + self.history = history } } - /// HTTP authentication - public struct Authorization: Hashable { + /// HTTP authentication. + public struct Authorization: Hashable, Sendable { private enum Scheme: Hashable { case Basic(String) case Bearer(String) @@ -285,18 +469,24 @@ extension HTTPClient { self.scheme = scheme } + /// HTTP basic auth. public static func basic(username: String, password: String) -> HTTPClient.Authorization { - return .basic(credentials: Base64.encode(bytes: "\(username):\(password)".utf8)) + .basic(credentials: Base64.encode(bytes: "\(username):\(password)".utf8)) } + /// HTTP basic auth. + /// + /// This version uses the raw string directly. public static func basic(credentials: String) -> HTTPClient.Authorization { - return .init(scheme: .Basic(credentials)) + .init(scheme: .Basic(credentials)) } + /// HTTP bearer auth public static func bearer(tokens: String) -> HTTPClient.Authorization { - return .init(scheme: .Bearer(tokens)) + .init(scheme: .Bearer(tokens)) } + /// The header string for this auth field. public var headerValue: String { switch self.scheme { case .Basic(let credentials): @@ -306,9 +496,23 @@ extension HTTPClient { } } } + + public struct RequestResponse: Sendable { + public var request: Request + public var responseHead: HTTPResponseHead + + public init(request: Request, responseHead: HTTPResponseHead) { + self.request = request + self.responseHead = responseHead + } + } } -public class ResponseAccumulator: HTTPClientResponseDelegate { +/// The default ``HTTPClientResponseDelegate``. +/// +/// This ``HTTPClientResponseDelegate`` buffers a complete HTTP response in memory. It does not stream the response body in. +/// The resulting ``Response`` type is ``HTTPClient/Response``. +public final class ResponseAccumulator: HTTPClientResponseDelegate { public typealias Response = HTTPClient.Response enum State { @@ -319,103 +523,214 @@ public class ResponseAccumulator: HTTPClientResponseDelegate { case error(Error) } - var state = State.idle - let request: HTTPClient.Request + public struct ResponseTooBigError: Error, CustomStringConvertible { + public var maxBodySize: Int + public init(maxBodySize: Int) { + self.maxBodySize = maxBodySize + } + + public var description: String { + "ResponseTooBigError: received response body exceeds maximum accepted size of \(self.maxBodySize) bytes" + } + } + + private struct MutableState: Sendable { + var history = [HTTPClient.RequestResponse]() + var state = State.idle + } + + private let state: NIOLockedValueBox + let requestMethod: HTTPMethod + let requestHost: String + + static let maxByteBufferSize = Int(UInt32.max) + + /// Maximum size in bytes of the HTTP response body that ``ResponseAccumulator`` will accept + /// until it will abort the request and throw an ``ResponseTooBigError``. + /// + /// Default is 2^32. + /// - precondition: not allowed to exceed 2^32 because `ByteBuffer` can not store more bytes + public let maxBodySize: Int - public init(request: HTTPClient.Request) { - self.request = request + public convenience init(request: HTTPClient.Request) { + self.init(request: request, maxBodySize: Self.maxByteBufferSize) + } + + /// - Parameters: + /// - request: The corresponding request of the response this delegate will be accumulating. + /// - maxBodySize: Maximum size in bytes of the HTTP response body that ``ResponseAccumulator`` will accept + /// until it will abort the request and throw an ``ResponseTooBigError``. + /// Default is 2^32. + /// - precondition: maxBodySize is not allowed to exceed 2^32 because `ByteBuffer` can not store more bytes + /// - warning: You can use ``ResponseAccumulator`` for just one request. + /// If you start another request, you need to initiate another ``ResponseAccumulator``. + public init(request: HTTPClient.Request, maxBodySize: Int) { + precondition(maxBodySize >= 0, "maxBodyLength is not allowed to be negative") + precondition( + maxBodySize <= Self.maxByteBufferSize, + "maxBodyLength is not allowed to exceed 2^32 because ByteBuffer can not store more bytes" + ) + self.requestMethod = request.method + self.requestHost = request.host + self.maxBodySize = maxBodySize + self.state = NIOLockedValueBox(MutableState()) + } + + public func didVisitURL( + task: HTTPClient.Task, + _ request: HTTPClient.Request, + _ head: HTTPResponseHead + ) { + self.state.withLockedValue { + $0.history.append(.init(request: request, responseHead: head)) + } } public func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - switch self.state { - case .idle: - self.state = .head(head) - case .head: - preconditionFailure("head already set") - case .body: - preconditionFailure("no head received before body") - case .end: - preconditionFailure("request already processed") - case .error: - break - } - return task.eventLoop.makeSucceededFuture(()) + let responseTooBig: Bool + + if self.requestMethod != .HEAD, + let contentLength = head.headers.first(name: "Content-Length"), + let announcedBodySize = Int(contentLength), + announcedBodySize > self.maxBodySize + { + responseTooBig = true + } else { + responseTooBig = false + } + + return self.state.withLockedValue { + switch $0.state { + case .idle: + if responseTooBig { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + $0.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } + + $0.state = .head(head) + case .head: + preconditionFailure("head already set") + case .body: + preconditionFailure("no head received before body") + case .end: + preconditionFailure("request already processed") + case .error: + break + } + return task.eventLoop.makeSucceededFuture(()) + } } public func didReceiveBodyPart(task: HTTPClient.Task, _ part: ByteBuffer) -> EventLoopFuture { - switch self.state { - case .idle: - preconditionFailure("no head received before body") - case .head(let head): - self.state = .body(head, part) - case .body(let head, var body): - // The compiler can't prove that `self.state` is dead here (and it kinda isn't, there's - // a cross-module call in the way) so we need to drop the original reference to `body` in - // `self.state` or we'll get a CoW. To fix that we temporarily set the state to `.end` (which - // has no associated data). We'll fix it at the bottom of this block. - self.state = .end - var part = part - body.writeBuffer(&part) - self.state = .body(head, body) - case .end: - preconditionFailure("request already processed") - case .error: - break - } - return task.eventLoop.makeSucceededFuture(()) + self.state.withLockedValue { + switch $0.state { + case .idle: + preconditionFailure("no head received before body") + case .head(let head): + guard part.readableBytes <= self.maxBodySize else { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + $0.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } + $0.state = .body(head, part) + case .body(let head, var body): + let newBufferSize = body.writerIndex + part.readableBytes + guard newBufferSize <= self.maxBodySize else { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + $0.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } + + // The compiler can't prove that `self.state` is dead here (and it kinda isn't, there's + // a cross-module call in the way) so we need to drop the original reference to `body` in + // `self.state` or we'll get a CoW. To fix that we temporarily set the state to `.end` (which + // has no associated data). We'll fix it at the bottom of this block. + $0.state = .end + var part = part + body.writeBuffer(&part) + $0.state = .body(head, body) + case .end: + preconditionFailure("request already processed") + case .error: + break + } + return task.eventLoop.makeSucceededFuture(()) + } } public func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.state = .error(error) + self.state.withLockedValue { + $0.state = .error(error) + } } public func didFinishRequest(task: HTTPClient.Task) throws -> Response { - switch self.state { - case .idle: - preconditionFailure("no head received before end") - case .head(let head): - return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: nil) - case .body(let head, let body): - return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: body) - case .end: - preconditionFailure("request already processed") - case .error(let error): - throw error + try self.state.withLockedValue { + switch $0.state { + case .idle: + preconditionFailure("no head received before end") + case .head(let head): + return Response( + host: self.requestHost, + status: head.status, + version: head.version, + headers: head.headers, + body: nil, + history: $0.history + ) + case .body(let head, let body): + return Response( + host: self.requestHost, + status: head.status, + version: head.version, + headers: head.headers, + body: body, + history: $0.history + ) + case .end: + preconditionFailure("request already processed") + case .error(let error): + throw error + } } } } -/// `HTTPClientResponseDelegate` allows an implementation to receive notifications about request processing and to control how response parts are processed. +/// ``HTTPClientResponseDelegate`` allows an implementation to receive notifications about request processing and to control how response parts are processed. +/// /// You can implement this protocol if you need fine-grained control over an HTTP request/response, for example, if you want to inspect the response /// headers before deciding whether to accept a response body, or if you want to stream your request body. Pass an instance of your conforming -/// class to the `HTTPClient.execute()` method and this package will call each delegate method appropriately as the request takes place./ +/// class to the ``HTTPClient/execute(request:delegate:eventLoop:deadline:)`` method and this package will call each delegate method appropriately as the request takes place. /// /// ### Backpressure /// -/// A `HTTPClientResponseDelegate` can be used to exert backpressure on the server response. This is achieved by way of the futures returned from -/// `didReceiveHead` and `didReceiveBodyPart`. The following functions are part of the "backpressure system" in the delegate: +/// A ``HTTPClientResponseDelegate`` can be used to exert backpressure on the server response. This is achieved by way of the futures returned from +/// ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` and ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``. +/// The following functions are part of the "backpressure system" in the delegate: /// -/// - `didReceiveHead` -/// - `didReceiveBodyPart` -/// - `didFinishRequest` -/// - `didReceiveError` +/// - ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` +/// - ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v`` +/// - ``HTTPClientResponseDelegate/didFinishRequest(task:)`` +/// - ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg`` /// -/// The first three methods are strictly _exclusive_, with that exclusivity managed by the futures returned by `didReceiveHead` and -/// `didReceiveBodyPart`. What this means is that until the returned future is completed, none of these three methods will be called -/// again. This allows delegates to rate limit the server to a capacity it can manage. `didFinishRequest` does not return a future, +/// The first three methods are strictly _exclusive_, with that exclusivity managed by the futures returned by ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` and +/// ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``. What this means is that until the returned future is completed, none of these three methods will be called +/// again. This allows delegates to rate limit the server to a capacity it can manage. ``HTTPClientResponseDelegate/didFinishRequest(task:)`` does not return a future, /// as we are expecting no more data from the server at this time. /// -/// `didReceiveError` is somewhat special: it signals the end of this regime. `didRecieveError` is not exclusive: it may be called at -/// any time, even if a returned future is not yet completed. `didReceiveError` is terminal, meaning that once it has been called none -/// of these four methods will be called again. This can be used as a signal to abandon all outstanding work. +/// ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg`` is somewhat special: it signals the end of this regime. ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg`` +/// is not exclusive: it may be called at any time, even if a returned future is not yet completed. ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg`` is terminal, meaning +/// that once it has been called none of these four methods will be called again. This can be used as a signal to abandon all outstanding work. /// /// - note: This delegate is strongly held by the `HTTPTaskHandler` -/// for the duration of the `Request` processing and will be +/// for the duration of the ``HTTPClient/Request`` processing and will be /// released together with the `HTTPTaskHandler` when channel is closed. /// Users of the library are not required to keep a reference to the /// object that implements this protocol, but may do so if needed. -public protocol HTTPClientResponseDelegate: AnyObject { - associatedtype Response +@preconcurrency +public protocol HTTPClientResponseDelegate: AnyObject, Sendable { + associatedtype Response: Sendable /// Called when the request head is sent. Will be called once. /// @@ -428,7 +743,7 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// /// - parameters: /// - task: Current request context. - /// - part: Request body `Part`. + /// - part: Request body part. func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) /// Called when the request is fully sent. Will be called once. @@ -437,7 +752,16 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// - task: Current request context. func didSendRequest(task: HTTPClient.Task) - /// Called when response head is received. Will be called once. + /// Called each time a response head is received (including redirects), and always called before ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``. + /// You can use this method to keep an entire history of the request/response chain. + /// + /// - parameters: + /// - task: Current request context. + /// - request: The request that was sent. + /// - head: Received response head. + func didVisitURL(task: HTTPClient.Task, _ request: HTTPClient.Request, _ head: HTTPResponseHead) + + /// Called when the final response head is received (after redirects). /// You must return an `EventLoopFuture` that you complete when you have finished processing the body part. /// You can create an already succeeded future by calling `task.eventLoop.makeSucceededFuture(())`. /// @@ -451,7 +775,7 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// You must return an `EventLoopFuture` that you complete when you have finished processing the body part. /// You can create an already succeeded future by calling `task.eventLoop.makeSucceededFuture(())`. /// - /// This function will not be called until the future returned by `didReceiveHead` has completed. + /// This function will not be called until the future returned by ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` has completed. /// /// This function will not be called for subsequent body parts until the previous future returned by a /// call to this function completes. @@ -464,19 +788,22 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// Called when error was thrown during request execution. Will be called zero or one time only. Request processing will be stopped after that. /// - /// This function may be called at any time: it does not respect the backpressure exerted by `didReceiveHead` and `didReceiveBodyPart`. - /// All outstanding work may be cancelled when this is received. Once called, no further calls will be made to `didReceiveHead`, `didReceiveBodyPart`, - /// or `didFinishRequest`. + /// This function may be called at any time: it does not respect the backpressure exerted by ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` + /// and ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``. + /// All outstanding work may be cancelled when this is received. Once called, no further calls will be made to + /// ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``, ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``, + /// or ``HTTPClientResponseDelegate/didFinishRequest(task:)``. /// /// - parameters: /// - task: Current request context. /// - error: Error that occured during response processing. func didReceiveError(task: HTTPClient.Task, _ error: Error) - /// Called when the complete HTTP request is finished. You must return an instance of your `Response` associated type. Will be called once, except if an error occurred. + /// Called when the complete HTTP request is finished. You must return an instance of your ``Response`` associated type. Will be called once, except if an error occurred. /// - /// This function will not be called until all futures returned by `didReceiveHead` and `didReceiveBodyPart` have completed. Once called, - /// no further calls will be made to `didReceiveHead`, `didReceiveBodyPart`, or `didReceiveError`. + /// This function will not be called until all futures returned by ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd`` and ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v`` + /// have completed. Once called, no further calls will be made to ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``, ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``, + /// or ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg``. /// /// - parameters: /// - task: Current request context. @@ -485,20 +812,43 @@ public protocol HTTPClientResponseDelegate: AnyObject { } extension HTTPClientResponseDelegate { + /// Default implementation of ``HTTPClientResponseDelegate/didSendRequest(task:)-9od5p``. + /// + /// By default, this does nothing. public func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) {} + /// Default implementation of ``HTTPClientResponseDelegate/didSendRequestPart(task:_:)-4qxap``. + /// + /// By default, this does nothing. public func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) {} + /// Default implementation of ``HTTPClientResponseDelegate/didSendRequest(task:)-3vqgm``. + /// + /// By default, this does nothing. public func didSendRequest(task: HTTPClient.Task) {} + /// Default implementation of ``HTTPClientResponseDelegate/didVisitURL(task:_:_:)-2el9y``. + /// + /// By default, this does nothing. + public func didVisitURL(task: HTTPClient.Task, _: HTTPClient.Request, _: HTTPResponseHead) {} + + /// Default implementation of ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``. + /// + /// By default, this does nothing. public func didReceiveHead(task: HTTPClient.Task, _: HTTPResponseHead) -> EventLoopFuture { - return task.eventLoop.makeSucceededFuture(()) + task.eventLoop.makeSucceededVoidFuture() } + /// Default implementation of ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``. + /// + /// By default, this does nothing. public func didReceiveBodyPart(task: HTTPClient.Task, _: ByteBuffer) -> EventLoopFuture { - return task.eventLoop.makeSucceededFuture(()) + task.eventLoop.makeSucceededVoidFuture() } + /// Default implementation of ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg``. + /// + /// By default, this does nothing. public func didReceiveError(task: HTTPClient.Task, _: Error) {} } @@ -507,7 +857,7 @@ extension URL { if self.path.isEmpty { return "/" } - return URLComponents(url: self, resolvingAgainstBaseURL: false)?.percentEncodedPath ?? self.path + return URLComponents(url: self, resolvingAgainstBaseURL: true)?.percentEncodedPath ?? self.path } var uri: String { @@ -521,7 +871,7 @@ extension URL { } func hasTheSameOrigin(as other: URL) -> Bool { - return self.host == other.host && self.scheme == other.scheme && self.port == other.port + self.host == other.host && self.scheme == other.scheme && self.port == other.port } /// Initializes a newly created HTTP URL connecting to a unix domain socket path. The socket path is encoded as the URL's host, replacing percent encoding invalid path characters, and will use the "http+unix" scheme. @@ -555,81 +905,139 @@ extension URL { } } -protocol HTTPClientTaskDelegate { - func cancel() +protocol HTTPClientTaskDelegate: Sendable { + func fail(_ error: Error) } extension HTTPClient { - /// Response execution context. Will be created by the library and could be used for obtaining + /// Response execution context. + /// + /// Will be created by the library and could be used for obtaining /// `EventLoopFuture` of the execution or cancellation of the execution. - public final class Task { + public final class Task: Sendable { /// The `EventLoop` the delegate will be executed on. public let eventLoop: EventLoop + /// The `Logger` used by the `Task` for logging. + public let logger: Logger // We are okay to store the logger here because a Task is for only one request. + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public var tracer: (any Tracer)? { + tracing.tracer + } + let tracing: TracingConfiguration let promise: EventLoopPromise - let logger: Logger // We are okay to store the logger here because a Task is for only one request. + + struct State: Sendable { + var isCancelled: Bool + var taskDelegate: HTTPClientTaskDelegate? + } + + private let state: NIOLockedValueBox var isCancelled: Bool { - self.lock.withLock { self._isCancelled } + self.state.withLockedValue { $0.isCancelled } } var taskDelegate: HTTPClientTaskDelegate? { get { - self.lock.withLock { self._taskDelegate } + self.state.withLockedValue { $0.taskDelegate } } set { - self.lock.withLock { self._taskDelegate = newValue } + self.state.withLockedValue { $0.taskDelegate = newValue } } } - private var _isCancelled: Bool = false - private var _taskDelegate: HTTPClientTaskDelegate? - private let lock = Lock() + private let makeOrGetFileIOThreadPool: @Sendable () -> NIOThreadPool - init(eventLoop: EventLoop, logger: Logger) { + /// The shared thread pool of a ``HTTPClient`` used for file IO. It is lazily created on first access. + internal var fileIOThreadPool: NIOThreadPool { + self.makeOrGetFileIOThreadPool() + } + + init( + eventLoop: EventLoop, + logger: Logger, + tracing: TracingConfiguration, + makeOrGetFileIOThreadPool: @escaping @Sendable () -> NIOThreadPool + ) { self.eventLoop = eventLoop self.promise = eventLoop.makePromise() self.logger = logger + self.tracing = tracing + self.makeOrGetFileIOThreadPool = makeOrGetFileIOThreadPool + self.state = NIOLockedValueBox(State(isCancelled: false, taskDelegate: nil)) } - static func failedTask(eventLoop: EventLoop, error: Error, logger: Logger) -> Task { - let task = self.init(eventLoop: eventLoop, logger: logger) + static func failedTask( + eventLoop: EventLoop, + error: Error, + logger: Logger, + tracing: TracingConfiguration, + makeOrGetFileIOThreadPool: @escaping @Sendable () -> NIOThreadPool + ) -> Task { + let task = self.init( + eventLoop: eventLoop, + logger: logger, + tracing: tracing, + makeOrGetFileIOThreadPool: makeOrGetFileIOThreadPool + ) task.promise.fail(error) return task } /// `EventLoopFuture` for the response returned by this request. public var futureResult: EventLoopFuture { - return self.promise.futureResult + self.promise.futureResult } /// Waits for execution of this request to complete. /// - /// - returns: The value of the `EventLoopFuture` when it completes. - /// - throws: The error value of the `EventLoopFuture` if it errors. - public func wait() throws -> Response { - return try self.promise.futureResult.wait() + /// - returns: The value of ``futureResult`` when it completes. + /// - throws: The error value of ``futureResult`` if it errors. + @available(*, noasync, message: "wait() can block indefinitely, prefer get()", renamed: "get()") + @preconcurrency + public func wait() throws -> Response where Response: Sendable { + try self.promise.futureResult.wait() } - /// Cancels the request execution. - public func cancel() { - let taskDelegate = self.lock.withLock { () -> HTTPClientTaskDelegate? in - self._isCancelled = true - return self._taskDelegate - } + /// Provides the result of this request. + /// + /// - warning: This method may violates Structured Concurrency because doesn't respect cancellation. + /// + /// - returns: The value of ``futureResult`` when it completes. + /// - throws: The error value of ``futureResult`` if it errors. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + @preconcurrency + public func get() async throws -> Response where Response: Sendable { + try await self.promise.futureResult.get() + } - taskDelegate?.cancel() + /// Initiate cancellation of a HTTP request. + /// + /// This method will return immeidately and doesn't wait for the cancellation to complete. + public func cancel() { + self.fail(reason: HTTPClientError.cancelled) } - func succeed(promise: EventLoopPromise?, - with value: Response, - delegateType: Delegate.Type, - closing: Bool) { - promise?.succeed(value) + /// Initiate cancellation of a HTTP request with an `error`. + /// + /// This method will return immeidately and doesn't wait for the cancellation to complete. + /// + /// - Parameter error: the error that is used to fail the promise + public func fail(reason error: Error) { + let taskDelegate = self.state.withLockedValue { state in + state.isCancelled = true + return state.taskDelegate + } + + taskDelegate?.fail(error) } - func fail(with error: Error, - delegateType: Delegate.Type) { + /// Called internally only, used to fail a task from within the state machine functionality. + func failInternal( + with error: Error + ) { self.promise.fail(error) } } @@ -639,7 +1047,7 @@ internal struct TaskCancelEvent {} // MARK: - RedirectHandler -internal struct RedirectHandler { +internal struct RedirectHandler { let request: HTTPClient.Request let redirectState: RedirectState let execute: (HTTPClient.Request, RedirectState) -> HTTPClient.Task @@ -656,7 +1064,7 @@ internal struct RedirectHandler { status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise - ) { + ) -> HTTPClient.Task? { do { var redirectState = self.redirectState try redirectState.redirect(to: redirectURL.absoluteString) @@ -676,13 +1084,19 @@ internal struct RedirectHandler { headers: headers, body: body ) - self.execute(newRequest, redirectState).futureResult.whenComplete { result in + + let newTask = self.execute(newRequest, redirectState) + + newTask.futureResult.whenComplete { result in promise.futureResult.eventLoop.execute { promise.completeWith(result) } } + + return newTask } catch { promise.fail(error) + return nil } } } @@ -693,7 +1107,7 @@ extension RequestBodyLength { self = .known(0) return } - guard let length = body.length else { + guard let length = body.contentLength else { self = .unknown return } diff --git a/Sources/AsyncHTTPClient/LRUCache.swift b/Sources/AsyncHTTPClient/LRUCache.swift index 0a01da0d2..f8b58c36a 100644 --- a/Sources/AsyncHTTPClient/LRUCache.swift +++ b/Sources/AsyncHTTPClient/LRUCache.swift @@ -52,9 +52,11 @@ struct LRUCache { @discardableResult mutating func append(key: Key, value: Value) -> Value { - let newElement = Element(generation: self.generation, - key: key, - value: value) + let newElement = Element( + generation: self.generation, + key: key, + value: value + ) if let found = self.bumpGenerationAndFindIndex(key: key) { self.elements[found] = newElement return value diff --git a/Sources/AsyncHTTPClient/NIOLoopBound+Execute.swift b/Sources/AsyncHTTPClient/NIOLoopBound+Execute.swift new file mode 100644 index 000000000..b25a0f00d --- /dev/null +++ b/Sources/AsyncHTTPClient/NIOLoopBound+Execute.swift @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore + +extension NIOLoopBound { + @inlinable + func execute(_ body: @Sendable @escaping (Value) -> Void) { + if self.eventLoop.inEventLoop { + body(self.value) + } else { + self.eventLoop.execute { + body(self.value) + } + } + } +} diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift b/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift index 4334bb9f9..148b4a4c4 100644 --- a/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift +++ b/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift @@ -12,15 +12,17 @@ // //===----------------------------------------------------------------------===// -#if canImport(Network) -import Network -#endif import NIOCore import NIOHTTP1 import NIOTransportServices +#if canImport(Network) +import Network +#endif + extension HTTPClient { #if canImport(Network) + /// A wrapper for `POSIX` errors thrown by `Network.framework`. public struct NWPOSIXError: Error, CustomStringConvertible { /// POSIX error code (enum) public let errorCode: POSIXErrorCode @@ -37,11 +39,12 @@ extension HTTPClient { self.reason = reason } - public var description: String { return self.reason } + public var description: String { self.reason } } + /// A wrapper for TLS errors thrown by `Network.framework`. public struct NWTLSError: Error, CustomStringConvertible { - /// TLS error status. List of TLS errors can be found in + /// TLS error status. List of TLS errors can be found in `` public let status: OSStatus /// actual reason, in human readable form @@ -56,11 +59,11 @@ extension HTTPClient { self.reason = reason } - public var description: String { return self.reason } + public var description: String { self.reason } } #endif - class NWErrorHandler: ChannelInboundHandler { + final class NWErrorHandler: ChannelInboundHandler { typealias InboundIn = HTTPClientResponsePart func errorCaught(context: ChannelHandlerContext, error: Error) { @@ -73,9 +76,9 @@ extension HTTPClient { if let error = error as? NWError { switch error { case .tls(let status): - return NWTLSError(status, reason: error.localizedDescription) + return NWTLSError(status, reason: String(describing: error)) case .posix(let errorCode): - return NWPOSIXError(errorCode, reason: error.localizedDescription) + return NWPOSIXError(errorCode, reason: String(describing: error)) default: return error } diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/NWWaitingHandler.swift b/Sources/AsyncHTTPClient/NIOTransportServices/NWWaitingHandler.swift new file mode 100644 index 000000000..d7c6055ec --- /dev/null +++ b/Sources/AsyncHTTPClient/NIOTransportServices/NWWaitingHandler.swift @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Network) +import Network +import NIOCore +import NIOHTTP1 +import NIOTransportServices + +@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) +final class NWWaitingHandler: ChannelInboundHandler { + typealias InboundIn = Any + typealias InboundOut = Any + + private var requester: Requester + private let connectionID: HTTPConnectionPool.Connection.ID + + init(requester: Requester, connectionID: HTTPConnectionPool.Connection.ID) { + self.requester = requester + self.connectionID = connectionID + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + if let waitingEvent = event as? NIOTSNetworkEvents.WaitingForConnectivity { + self.requester.waitingForConnectivity( + self.connectionID, + error: HTTPClient.NWErrorHandler.translateError(waitingEvent.transientError) + ) + } + context.fireUserInboundEventTriggered(event) + } +} +#endif diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift b/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift index e20f52634..e8278e095 100644 --- a/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift +++ b/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift @@ -57,20 +57,23 @@ extension TLSVersion { } } -@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) +@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) extension TLSConfiguration { /// Dispatch queue used by Network framework TLS to control certificate verification - static var tlsDispatchQueue = DispatchQueue(label: "TLSDispatch") + static let tlsDispatchQueue = DispatchQueue(label: "TLSDispatch") /// create NWProtocolTLS.Options for use with NIOTransportServices from the NIOSSL TLSConfiguration /// /// - Parameter eventLoop: EventLoop to wait for creation of options on /// - Returns: Future holding NWProtocolTLS Options - func getNWProtocolTLSOptions(on eventLoop: EventLoop) -> EventLoopFuture { + func getNWProtocolTLSOptions( + on eventLoop: EventLoop, + serverNameIndicatorOverride: String? + ) -> EventLoopFuture { let promise = eventLoop.makePromise(of: NWProtocolTLS.Options.self) Self.tlsDispatchQueue.async { do { - let options = try self.getNWProtocolTLSOptions() + let options = try self.getNWProtocolTLSOptions(serverNameIndicatorOverride: serverNameIndicatorOverride) promise.succeed(options) } catch { promise.fail(error) @@ -82,27 +85,42 @@ extension TLSConfiguration { /// create NWProtocolTLS.Options for use with NIOTransportServices from the NIOSSL TLSConfiguration /// /// - Returns: Equivalent NWProtocolTLS Options - func getNWProtocolTLSOptions() throws -> NWProtocolTLS.Options { + func getNWProtocolTLSOptions(serverNameIndicatorOverride: String?) throws -> NWProtocolTLS.Options { let options = NWProtocolTLS.Options() let useMTELGExplainer = """ - You can still use this configuration option on macOS if you initialize HTTPClient \ - with a MultiThreadedEventLoopGroup. Please note that using MultiThreadedEventLoopGroup \ - will make AsyncHTTPClient use NIO on BSD Sockets and not Network.framework (which is the preferred \ - platform networking stack). - """ + You can still use this configuration option on macOS if you initialize HTTPClient \ + with a MultiThreadedEventLoopGroup. Please note that using MultiThreadedEventLoopGroup \ + will make AsyncHTTPClient use NIO on BSD Sockets and not Network.framework (which is the preferred \ + platform networking stack). + """ + + if let serverNameIndicatorOverride = serverNameIndicatorOverride { + serverNameIndicatorOverride.withCString { serverNameIndicatorOverride in + sec_protocol_options_set_tls_server_name(options.securityProtocolOptions, serverNameIndicatorOverride) + } + } // minimum TLS protocol if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) { - sec_protocol_options_set_min_tls_protocol_version(options.securityProtocolOptions, self.minimumTLSVersion.nwTLSProtocolVersion) + sec_protocol_options_set_min_tls_protocol_version( + options.securityProtocolOptions, + self.minimumTLSVersion.nwTLSProtocolVersion + ) } else { - sec_protocol_options_set_tls_min_version(options.securityProtocolOptions, self.minimumTLSVersion.sslProtocol) + sec_protocol_options_set_tls_min_version( + options.securityProtocolOptions, + self.minimumTLSVersion.sslProtocol + ) } // maximum TLS protocol if let maximumTLSVersion = self.maximumTLSVersion { if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) { - sec_protocol_options_set_max_tls_protocol_version(options.securityProtocolOptions, maximumTLSVersion.nwTLSProtocolVersion) + sec_protocol_options_set_max_tls_protocol_version( + options.securityProtocolOptions, + maximumTLSVersion.nwTLSProtocolVersion + ) } else { sec_protocol_options_set_tls_max_version(options.securityProtocolOptions, maximumTLSVersion.sslProtocol) } @@ -115,11 +133,6 @@ extension TLSConfiguration { } } - // the certificate chain - if self.certificateChain.count > 0 { - preconditionFailure("TLSConfiguration.certificateChain is not supported. \(useMTELGExplainer)") - } - // cipher suites if self.cipherSuites.count > 0 { // TODO: Requires NIOSSL to provide list of cipher values before we can continue @@ -160,8 +173,10 @@ extension TLSConfiguration { break } - precondition(self.certificateVerification != .noHostnameVerification, - "TLSConfiguration.certificateVerification = .noHostnameVerification is not supported. \(useMTELGExplainer)") + precondition( + self.certificateVerification != .noHostnameVerification, + "TLSConfiguration.certificateVerification = .noHostnameVerification is not supported. \(useMTELGExplainer)" + ) if certificateVerification != .fullVerification || trustRoots != nil { // add verify block to control certificate verification @@ -195,7 +210,8 @@ extension TLSConfiguration { } } } - }, Self.tlsDispatchQueue + }, + Self.tlsDispatchQueue ) } return options diff --git a/Sources/AsyncHTTPClient/RedirectState.swift b/Sources/AsyncHTTPClient/RedirectState.swift index c4e427ef1..95de2d508 100644 --- a/Sources/AsyncHTTPClient/RedirectState.swift +++ b/Sources/AsyncHTTPClient/RedirectState.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -import struct Foundation.URL import NIOHTTP1 +import struct Foundation.URL + typealias RedirectMode = HTTPClient.Configuration.RedirectConfiguration.Mode struct RedirectState { diff --git a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift index 557af2af1..37b2a42f0 100644 --- a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift +++ b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift @@ -12,10 +12,11 @@ // //===----------------------------------------------------------------------===// -import struct Foundation.URL import NIOCore import NIOHTTP1 +import struct Foundation.URL + extension HTTPClient { /// The maximum body size allowed, before a redirect response is cancelled. 3KB. /// @@ -29,11 +30,15 @@ extension HTTPClient { extension RequestBag { struct StateMachine { fileprivate enum State { - case initialized - case queued(HTTPRequestScheduler) + case initialized(RedirectHandler?) + case queued(HTTPRequestScheduler, RedirectHandler?) + /// if the deadline was exceeded while in the `.queued(_:)` state, + /// we wait until the request pool fails the request with a potential more descriptive error message, + /// if a connection failure has occured while the request was queued. + case deadlineExceededWhileQueued case executing(HTTPRequestExecutor, RequestStreamState, ResponseStreamState) case finished(error: Error?) - case redirected(HTTPRequestExecutor, Int, HTTPResponseHead, URL) + case redirected(HTTPRequestExecutor, RedirectHandler, Int, HTTPResponseHead, URL) case modifying } @@ -51,23 +56,22 @@ extension RequestBag { case eof } - case initialized + case initialized(RedirectHandler?) case buffering(CircularBuffer, next: Next) case waitingForRemote } - private var state: State = .initialized - private let redirectHandler: RedirectHandler? + private var state: State init(redirectHandler: RedirectHandler?) { - self.redirectHandler = redirectHandler + self.state = .initialized(redirectHandler) } } } extension RequestBag.StateMachine { mutating func requestWasQueued(_ scheduler: HTTPRequestScheduler) { - guard case .initialized = self.state else { + guard case .initialized(let redirectHandler) = self.state else { // There might be a race between `requestWasQueued` and `willExecuteRequest`: // // If the request is created and passed to the HTTPClient on thread A, it will move into @@ -87,16 +91,26 @@ extension RequestBag.StateMachine { return } - self.state = .queued(scheduler) + self.state = .queued(scheduler, redirectHandler) } - mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> Bool { + enum WillExecuteRequestAction { + case cancelExecuter(HTTPRequestExecutor) + case failTaskAndCancelExecutor(Error, HTTPRequestExecutor) + case none + } + + mutating func willExecuteRequest(_ executor: HTTPRequestExecutor) -> WillExecuteRequestAction { switch self.state { - case .initialized, .queued: - self.state = .executing(executor, .initialized, .initialized) - return true + case .initialized(let redirectHandler), .queued(_, let redirectHandler): + self.state = .executing(executor, .initialized, .initialized(redirectHandler)) + return .none + case .deadlineExceededWhileQueued: + let error: Error = HTTPClientError.deadlineExceeded + self.state = .finished(error: error) + return .failTaskAndCancelExecutor(error, executor) case .finished(error: .some): - return false + return .cancelExecuter(executor) case .executing, .redirected, .finished(error: .none), .modifying: preconditionFailure("Invalid state: \(self.state)") } @@ -110,11 +124,11 @@ extension RequestBag.StateMachine { mutating func resumeRequestBodyStream() -> ResumeProducingAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("A request stream can only be resumed, if the request was started") - case .executing(let executor, .initialized, .initialized): - self.state = .executing(executor, .producing, .initialized) + case .executing(let executor, .initialized, .initialized(let redirectHandler)): + self.state = .executing(executor, .producing, .initialized(redirectHandler)) return .startWriter case .executing(_, .producing, _): @@ -150,7 +164,7 @@ extension RequestBag.StateMachine { mutating func pauseRequestBodyStream() { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("A request stream can only be paused, if the request was started") case .executing(let executor, let requestState, let responseState): switch requestState { @@ -185,7 +199,7 @@ extension RequestBag.StateMachine { mutating func writeNextRequestPart(_ part: IOData, taskEventLoop: EventLoop) -> WriteAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(let executor, let requestState, let responseState): switch requestState { @@ -231,7 +245,7 @@ extension RequestBag.StateMachine { mutating func finishRequestBodyStream(_ result: Result) -> FinishAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(let executor, let requestState, let responseState): switch requestState { @@ -282,27 +296,29 @@ extension RequestBag.StateMachine { /// - Returns: Whether the response should be forwarded to the delegate. Will be `false` if the request follows a redirect. mutating func receiveResponseHead(_ head: HTTPResponseHead) -> ReceiveResponseHeadAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response, if the request hasn't started yet.") case .executing(let executor, let requestState, let responseState): - guard case .initialized = responseState else { + guard case .initialized(let redirectHandler) = responseState else { preconditionFailure("If we receive a response, we must not have received something else before") } - if let redirectURL = self.redirectHandler?.redirectTarget( - status: head.status, - responseHeaders: head.headers - ) { + if let redirectHandler = redirectHandler, + let redirectURL = redirectHandler.redirectTarget( + status: head.status, + responseHeaders: head.headers + ) + { // If we will redirect, we need to consume the response's body ASAP, to be able to // reuse the existing connection. We will consume a response body, if the body is // smaller than 3kb. switch head.contentLength { case .some(0...(HTTPClient.maxBodySizeRedirectResponse)), .none: - self.state = .redirected(executor, 0, head, redirectURL) + self.state = .redirected(executor, redirectHandler, 0, head, redirectURL) return .signalBodyDemand(executor) case .some: self.state = .finished(error: HTTPClientError.cancelled) - return .redirect(executor, self.redirectHandler!, head, redirectURL) + return .redirect(executor, redirectHandler, head, redirectURL) } } else { self.state = .executing(executor, requestState, .buffering(.init(), next: .askExecutorForMore)) @@ -328,14 +344,16 @@ extension RequestBag.StateMachine { mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ReceiveResponseBodyAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") case .executing(_, _, .initialized): preconditionFailure("If we receive a response body, we must have received a head before") case .executing(let executor, let requestState, .buffering(var currentBuffer, next: let next)): guard case .askExecutorForMore = next else { - preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") + preconditionFailure( + "If we have received an error or eof before, why did we get another body part? Next: \(next)" + ) } self.state = .modifying @@ -347,19 +365,23 @@ extension RequestBag.StateMachine { self.state = .executing(executor, requestState, .buffering(currentBuffer, next: next)) return .none case .executing(let executor, let requestState, .waitingForRemote): - var buffer = buffer - let first = buffer.removeFirst() - self.state = .executing(executor, requestState, .buffering(buffer, next: .askExecutorForMore)) - return .forwardResponsePart(first) - case .redirected(let executor, var receivedBytes, let head, let redirectURL): + if buffer.count > 0 { + var buffer = buffer + let first = buffer.removeFirst() + self.state = .executing(executor, requestState, .buffering(buffer, next: .askExecutorForMore)) + return .forwardResponsePart(first) + } else { + return .none + } + case .redirected(let executor, let redirectHandler, var receivedBytes, let head, let redirectURL): let partsLength = buffer.reduce(into: 0) { $0 += $1.readableBytes } receivedBytes += partsLength if receivedBytes > HTTPClient.maxBodySizeRedirectResponse { self.state = .finished(error: HTTPClientError.cancelled) - return .redirect(executor, self.redirectHandler!, head, redirectURL) + return .redirect(executor, redirectHandler, head, redirectURL) } else { - self.state = .redirected(executor, receivedBytes, head, redirectURL) + self.state = .redirected(executor, redirectHandler, receivedBytes, head, redirectURL) return .signalBodyDemand(executor) } @@ -381,14 +403,16 @@ extension RequestBag.StateMachine { mutating func succeedRequest(_ newChunks: CircularBuffer?) -> ReceiveResponseEndAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") case .executing(_, _, .initialized): preconditionFailure("If we receive a response body, we must have received a head before") case .executing(let executor, let requestState, .buffering(var buffer, next: let next)): guard case .askExecutorForMore = next else { - preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") + preconditionFailure( + "If we have received an error or eof before, why did we get another body part? Next: \(next)" + ) } if buffer.isEmpty, let newChunks = newChunks, !newChunks.isEmpty { @@ -410,9 +434,9 @@ extension RequestBag.StateMachine { self.state = .executing(executor, requestState, .buffering(newChunks, next: .eof)) return .consume(first) - case .redirected(_, _, let head, let redirectURL): + case .redirected(_, let redirectHandler, _, let head, let redirectURL): self.state = .finished(error: nil) - return .redirect(self.redirectHandler!, head, redirectURL) + return .redirect(redirectHandler, head, redirectURL) case .finished(error: .some): return .none @@ -443,10 +467,12 @@ extension RequestBag.StateMachine { private mutating func failWithConsumptionError(_ error: Error) -> ConsumeAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(_, _, .initialized): - preconditionFailure("Invalid state: Must have received response head, before this method is called for the first time") + preconditionFailure( + "Invalid state: Must have received response head, before this method is called for the first time" + ) case .executing(_, _, .buffering(_, next: .error(let connectionError))): // if an error was received from the connection, we fail the task with the one @@ -459,17 +485,23 @@ extension RequestBag.StateMachine { return .failTask(error, executorToCancel: executor) case .executing(_, _, .waitingForRemote): - preconditionFailure("Invalid state... We just returned from a consumption function. We can't already be waiting") + preconditionFailure( + "Invalid state... We just returned from a consumption function. We can't already be waiting" + ) case .redirected: - preconditionFailure("Invalid state... Redirect don't call out to delegate functions. Thus we should never land here.") + preconditionFailure( + "Invalid state... Redirect don't call out to delegate functions. Thus we should never land here." + ) case .finished(error: .some): // don't overwrite existing errors return .doNothing case .finished(error: .none): - preconditionFailure("Invalid state... If no error occured, this must not be called, after the request was finished") + preconditionFailure( + "Invalid state... If no error occured, this must not be called, after the request was finished" + ) case .modifying: preconditionFailure() @@ -478,11 +510,13 @@ extension RequestBag.StateMachine { private mutating func consumeMoreBodyData() -> ConsumeAction { switch self.state { - case .initialized, .queued: + case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(_, _, .initialized): - preconditionFailure("Invalid state: Must have received response head, before this method is called for the first time") + preconditionFailure( + "Invalid state: Must have received response head, before this method is called for the first time" + ) case .executing(let executor, let requestState, .buffering(var buffer, next: .askExecutorForMore)): self.state = .modifying @@ -512,7 +546,9 @@ extension RequestBag.StateMachine { return .failTask(error, executorToCancel: nil) case .executing(_, _, .waitingForRemote): - preconditionFailure("Invalid state... We just returned from a consumption function. We can't already be waiting") + preconditionFailure( + "Invalid state... We just returned from a consumption function. We can't already be waiting" + ) case .redirected: return .doNothing @@ -521,15 +557,42 @@ extension RequestBag.StateMachine { return .doNothing case .finished(error: .none): - preconditionFailure("Invalid state... If no error occurred, this must not be called, after the request was finished") + preconditionFailure( + "Invalid state... If no error occurred, this must not be called, after the request was finished" + ) case .modifying: preconditionFailure() } } + enum DeadlineExceededAction { + case cancelScheduler(HTTPRequestScheduler?) + case fail(FailAction) + } + + mutating func deadlineExceeded() -> DeadlineExceededAction { + switch self.state { + case .queued(let queuer, _): + /// We do not fail the request immediately because we want to give the scheduler a chance of throwing a better error message + /// We therefore depend on the scheduler failing the request after we cancel the request. + self.state = .deadlineExceededWhileQueued + return .cancelScheduler(queuer) + + case .initialized, + .deadlineExceededWhileQueued, + .executing, + .finished, + .redirected, + .modifying: + /// if we are not in the queued state, we can fail early by just calling down to `self.fail(_:)` + /// which does the appropriate state transition for us. + return .fail(self.fail(HTTPClientError.deadlineExceeded)) + } + } + enum FailAction { - case failTask(HTTPRequestScheduler?, HTTPRequestExecutor?) + case failTask(Error, HTTPRequestScheduler?, HTTPRequestExecutor?) case cancelExecutor(HTTPRequestExecutor) case none } @@ -538,31 +601,44 @@ extension RequestBag.StateMachine { switch self.state { case .initialized: self.state = .finished(error: error) - return .failTask(nil, nil) - case .queued(let queuer): + return .failTask(error, nil, nil) + case .queued(let queuer, _): self.state = .finished(error: error) - return .failTask(queuer, nil) + return .failTask(error, queuer, nil) case .executing(let executor, let requestState, .buffering(_, next: .eof)): self.state = .executing(executor, requestState, .buffering(.init(), next: .error(error))) return .cancelExecutor(executor) case .executing(let executor, _, .buffering(_, next: .askExecutorForMore)): self.state = .finished(error: error) - return .failTask(nil, executor) + return .failTask(error, nil, executor) case .executing(let executor, _, .buffering(_, next: .error(_))): // this would override another error, let's keep the first one return .cancelExecutor(executor) case .executing(let executor, _, .initialized): self.state = .finished(error: error) - return .failTask(nil, executor) + return .failTask(error, nil, executor) case .executing(let executor, _, .waitingForRemote): self.state = .finished(error: error) - return .failTask(nil, executor) + return .failTask(error, nil, executor) case .redirected: self.state = .finished(error: error) - return .failTask(nil, nil) + return .failTask(error, nil, nil) case .finished(.none): // An error occurred after the request has finished. Ignore... return .none + case .deadlineExceededWhileQueued: + let realError: Error = { + if (error as? HTTPClientError) == .cancelled { + /// if we just get a `HTTPClientError.cancelled` we can use the original cancellation reason + /// to give a more descriptive error to the user. + return HTTPClientError.deadlineExceeded + } else { + /// otherwise we already had an intermediate connection error which we should present to the user instead + return error + } + }() + self.state = .finished(error: realError) + return .failTask(realError, nil, nil) case .finished(.some(_)): // this might happen, if the stream consumer has failed... let's just drop the data return .none diff --git a/Sources/AsyncHTTPClient/RequestBag+Tracing.swift b/Sources/AsyncHTTPClient/RequestBag+Tracing.swift new file mode 100644 index 000000000..729b6256a --- /dev/null +++ b/Sources/AsyncHTTPClient/RequestBag+Tracing.swift @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOHTTP1 +import NIOSSL +import Tracing + +extension RequestBag.LoopBoundState { + + /// Starts the "overall" Span that encompases the beginning of a request until receipt of the head part of the response. + mutating func startRequestSpan(tracer: T?) { + guard #available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *), + let tracer = tracer as! (any Tracer)? + else { + return + } + + assert( + self.activeSpan == nil, + "Unexpected active span when starting new request span! Was: \(String(describing: self.activeSpan))" + ) + self.activeSpan = tracer.startSpan("\(request.method)", ofKind: .client) + } + + /// Fails the active overall span given some internal error, e.g. timeout, pool shutdown etc. + /// This is not to be used for failing a span given a failure status coded HTTPResponse. + mutating func failRequestSpanAsCancelled() { + if #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) { + failRequestSpan(error: CancellationError()) + } else { + failRequestSpan(error: HTTPRequestCancellationError()) + } + } + + mutating func failRequestSpan(error: any Error) { + guard let span = activeSpan else { + return + } + + span.recordError(error) + span.end() + + self.activeSpan = nil + } + + /// Ends the active overall span upon receipt of the response head. + /// + /// If the status code is in error range, this will automatically fail the span. + mutating func endRequestSpan(response: HTTPResponseHead) { + guard let span = activeSpan else { + return + } + + TracingSupport.handleResponseStatusCode(span, response.status, keys: tracing.attributeKeys) + + span.end() + self.activeSpan = nil + } +} diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index b4aeef0e7..ff3ed8442 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -17,18 +17,48 @@ import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import NIOSSL +import Tracing + +@preconcurrency +final class RequestBag: Sendable { + /// Defends against the call stack getting too large when consuming body parts. + /// + /// If the response body comes in lots of tiny chunks, we'll deliver those tiny chunks to users + /// one at a time. + private static var maxConsumeBodyPartStackDepth: Int { + 50 + } + + let poolKey: ConnectionPool.Key -final class RequestBag { let task: HTTPClient.Task var eventLoop: EventLoop { self.task.eventLoop } private let delegate: Delegate - private let request: HTTPClient.Request - // the request state is synchronized on the task eventLoop - private var state: StateMachine + struct LoopBoundState: @unchecked Sendable { + // The 'StateMachine' *isn't* Sendable (it holds various objects which aren't). This type + // needs to be sendable so that we can construct a loop bound box off of the event loop + // to hold this state and then subsequently only access it from the event loop. This needs + // to happen so that the request bag can be constructed off of the event loop. If it's + // constructed on the event loop then there's a timing window between users issuing + // a request and calling shutdown where the underlying pool doesn't know about the request + // so the shutdown call may cancel it. + var request: HTTPClient.Request + var state: StateMachine + var consumeBodyPartStackDepth: Int + // if a redirect occurs, we store the task for it so we can propagate cancellation + var redirectTask: HTTPClient.Task? = nil + + // - Distributed tracing + var tracing: HTTPClient.TracingConfiguration + // The current span, representing the entire request/response made by an execute call. + var activeSpan: (any Span)? = nil + } + + private let loopBoundState: NIOLoopBoundBox // MARK: HTTPClientTask properties @@ -36,6 +66,16 @@ final class RequestBag { self.task.logger } + // Available unconditionally, so we can simplify callsites which can just try to pass this value + // regardless if the real tracer exists or not. + var anyTracer: (any Sendable)? { + self.task.tracing._tracer + } + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + var tracer: (any Tracer)? { + self.task.tracer + } + let connectionDeadline: NIODeadline let requestOptions: RequestOptions @@ -45,17 +85,28 @@ final class RequestBag { let eventLoopPreference: HTTPClient.EventLoopPreference - init(request: HTTPClient.Request, - eventLoopPreference: HTTPClient.EventLoopPreference, - task: HTTPClient.Task, - redirectHandler: RedirectHandler?, - connectionDeadline: NIODeadline, - requestOptions: RequestOptions, - delegate: Delegate) throws { + let tlsConfiguration: TLSConfiguration? + + init( + request: HTTPClient.Request, + eventLoopPreference: HTTPClient.EventLoopPreference, + task: HTTPClient.Task, + redirectHandler: RedirectHandler?, + connectionDeadline: NIODeadline, + requestOptions: RequestOptions, + delegate: Delegate + ) throws { + self.poolKey = .init(request, dnsOverride: requestOptions.dnsOverride) self.eventLoopPreference = eventLoopPreference self.task = task - self.state = .init(redirectHandler: redirectHandler) - self.request = request + + let loopBoundState = LoopBoundState( + request: request, + state: StateMachine(redirectHandler: redirectHandler), + consumeBodyPartStackDepth: 0, + tracing: task.tracing + ) + self.loopBoundState = NIOLoopBoundBox.makeBoxSendingValue(loopBoundState, eventLoop: task.eventLoop) self.connectionDeadline = connectionDeadline self.requestOptions = requestOptions self.delegate = delegate @@ -64,6 +115,8 @@ final class RequestBag { self.requestHead = head self.requestFramingMetadata = metadata + self.tlsConfiguration = request.tlsConfiguration + self.task.taskDelegate = self self.task.futureResult.whenComplete { _ in self.task.taskDelegate = nil @@ -72,40 +125,47 @@ final class RequestBag { private func requestWasQueued0(_ scheduler: HTTPRequestScheduler) { self.logger.debug("Request was queued (waiting for a connection to become available)") - - self.task.eventLoop.assertInEventLoop() - self.state.requestWasQueued(scheduler) + self.loopBoundState.value.state.requestWasQueued(scheduler) } // MARK: - Request - private func willExecuteRequest0(_ executor: HTTPRequestExecutor) { - self.task.eventLoop.assertInEventLoop() - if !self.state.willExecuteRequest(executor) { - return executor.cancelRequest(self) + // Immediately start a span for the "whole" request + self.loopBoundState.value.startRequestSpan(tracer: self.anyTracer) + + let action = self.loopBoundState.value.state.willExecuteRequest(executor) + switch action { + case .cancelExecuter(let executor): + executor.cancelRequest(self) + self.loopBoundState.value.failRequestSpanAsCancelled() + case .failTaskAndCancelExecutor(let error, let executor): + self.delegate.didReceiveError(task: self.task, error) + self.task.failInternal(with: error) + executor.cancelRequest(self) + self.loopBoundState.value.failRequestSpan(error: error) + case .none: + break } } private func requestHeadSent0() { - self.task.eventLoop.assertInEventLoop() - self.delegate.didSendRequestHead(task: self.task, self.requestHead) - if self.request.body == nil { + if self.loopBoundState.value.request.body == nil { self.delegate.didSendRequest(task: self.task) } } private func resumeRequestBodyStream0() { - self.task.eventLoop.assertInEventLoop() - - let produceAction = self.state.resumeRequestBodyStream() + let produceAction = self.loopBoundState.value.state.resumeRequestBodyStream() switch produceAction { case .startWriter: - guard let body = self.request.body else { + guard let body = self.loopBoundState.value.request.body else { preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream") } + self.loopBoundState.value.request.body = nil let writer = HTTPClient.Body.StreamWriter { self.writeNextRequestPart($0) @@ -124,9 +184,7 @@ final class RequestBag { } private func pauseRequestBodyStream0() { - self.task.eventLoop.assertInEventLoop() - - self.state.pauseRequestBodyStream() + self.loopBoundState.value.state.pauseRequestBodyStream() } private func writeNextRequestPart(_ part: IOData) -> EventLoopFuture { @@ -140,39 +198,39 @@ final class RequestBag { } private func writeNextRequestPart0(_ part: IOData) -> EventLoopFuture { - self.eventLoop.assertInEventLoop() - - let action = self.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop) + let action = self.loopBoundState.value.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop) switch action { case .failTask(let error): self.delegate.didReceiveError(task: self.task, error) - self.task.fail(with: error, delegateType: Delegate.self) + self.task.failInternal(with: error) return self.task.eventLoop.makeFailedFuture(error) case .failFuture(let error): return self.task.eventLoop.makeFailedFuture(error) case .write(let part, let writer, let future): - writer.writeRequestBodyPart(part, request: self) - self.delegate.didSendRequestPart(task: self.task, part) + let promise = self.task.eventLoop.makePromise(of: Void.self) + promise.futureResult.whenSuccess { + self.delegate.didSendRequestPart(task: self.task, part) + } + writer.writeRequestBodyPart(part, request: self, promise: promise) return future } } private func finishRequestBodyStream(_ result: Result) { - self.task.eventLoop.assertInEventLoop() - - let action = self.state.finishRequestBodyStream(result) + let action = self.loopBoundState.value.state.finishRequestBodyStream(result) switch action { case .none: break - case .forwardStreamFinished(let writer, let promise): - writer.finishRequestBodyStream(self) - promise?.succeed(()) - - self.delegate.didSendRequest(task: self.task) + case .forwardStreamFinished(let writer, let writerPromise): + let promise = writerPromise ?? self.task.eventLoop.makePromise(of: Void.self) + promise.futureResult.whenSuccess { + self.delegate.didSendRequest(task: self.task) + } + writer.finishRequestBodyStream(self, promise: promise) case .forwardStreamFailureAndFailTask(let writer, let error, let promise): writer.cancelRequest(self) @@ -193,10 +251,11 @@ final class RequestBag { // MARK: - Response - private func receiveResponseHead0(_ head: HTTPResponseHead) { - self.task.eventLoop.assertInEventLoop() + self.delegate.didVisitURL(task: self.task, self.loopBoundState.value.request, head) + self.loopBoundState.value.endRequestSpan(response: head) // runs most likely on channel eventLoop - switch self.state.receiveResponseHead(head) { + switch self.loopBoundState.value.state.receiveResponseHead(head) { case .none: break @@ -204,7 +263,11 @@ final class RequestBag { executor.demandResponseBodyStream(self) case .redirect(let executor, let handler, let head, let newURL): - handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) executor.cancelRequest(self) case .forwardResponseHead(let head): @@ -218,9 +281,7 @@ final class RequestBag { } private func receiveResponseBodyParts0(_ buffer: CircularBuffer) { - self.task.eventLoop.assertInEventLoop() - - switch self.state.receiveResponseBodyParts(buffer) { + switch self.loopBoundState.value.state.receiveResponseBodyParts(buffer) { case .none: break @@ -228,7 +289,11 @@ final class RequestBag { executor.demandResponseBodyStream(self) case .redirect(let executor, let handler, let head, let newURL): - handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) executor.cancelRequest(self) case .forwardResponsePart(let part): @@ -242,8 +307,7 @@ final class RequestBag { } private func succeedRequest0(_ buffer: CircularBuffer?) { - self.task.eventLoop.assertInEventLoop() - let action = self.state.succeedRequest(buffer) + let action = self.loopBoundState.value.state.succeedRequest(buffer) switch action { case .none: @@ -252,14 +316,7 @@ final class RequestBag { self.delegate.didReceiveBodyPart(task: self.task, buffer) .hop(to: self.task.eventLoop) .whenComplete { - switch $0 { - case .success: - self.consumeMoreBodyData0(resultOfPreviousConsume: $0) - case .failure(let error): - // if in the response stream consumption an error has occurred, we need to - // cancel the running request and fail the task. - self.fail(error) - } + self.consumeMoreBodyData0(resultOfPreviousConsume: $0) } case .succeedRequest: @@ -271,25 +328,48 @@ final class RequestBag { } case .redirect(let handler, let head, let newURL): - handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) } } private func consumeMoreBodyData0(resultOfPreviousConsume result: Result) { - self.task.eventLoop.assertInEventLoop() + // We get defensive here about the maximum stack depth. It's possible for the `didReceiveBodyPart` + // future to be returned to us completed. If it is, we will recurse back into this method. To + // break that recursion we have a max stack depth which we increment and decrement in this method: + // if it gets too large, instead of recurring we'll insert an `eventLoop.execute`, which will + // manually break the recursion and unwind the stack. + // + // Note that we don't bother starting this at the various other call sites that _begin_ stacks + // that risk ending up in this loop. That's because we don't need an accurate count: our limit is + // a best-effort target anyway, one stack frame here or there does not put us at risk. We're just + // trying to prevent ourselves looping out of control. + self.loopBoundState.value.consumeBodyPartStackDepth += 1 + defer { + self.loopBoundState.value.consumeBodyPartStackDepth -= 1 + assert(self.loopBoundState.value.consumeBodyPartStackDepth >= 0) + } - let consumptionAction = self.state.consumeMoreBodyData(resultOfPreviousConsume: result) + let consumptionAction = self.loopBoundState.value.state.consumeMoreBodyData( + resultOfPreviousConsume: result + ) switch consumptionAction { case .consume(let byteBuffer): self.delegate.didReceiveBodyPart(task: self.task, byteBuffer) .hop(to: self.task.eventLoop) - .whenComplete { - switch $0 { - case .success: - self.consumeMoreBodyData0(resultOfPreviousConsume: $0) - case .failure(let error): - self.fail(error) + .assumeIsolated() + .whenComplete { result in + if self.loopBoundState.value.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth { + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } else { + // We need to unwind the stack, let's take a break. + self.task.eventLoop.assumeIsolated().execute { + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } } } @@ -298,7 +378,7 @@ final class RequestBag { case .finishStream: do { let response = try self.delegate.didFinishRequest(task: self.task) - self.task.promise.succeed(response) + self.task.promise.assumeIsolated().succeed(response) } catch { self.task.promise.fail(error) } @@ -312,12 +392,16 @@ final class RequestBag { } private func fail0(_ error: Error) { - self.task.eventLoop.assertInEventLoop() + let action = self.loopBoundState.value.state.fail(error) - let action = self.state.fail(error) + self.executeFailAction0(action) + + self.loopBoundState.value.redirectTask?.fail(reason: error) + } + private func executeFailAction0(_ action: RequestBag.StateMachine.FailAction) { switch action { - case .failTask(let scheduler, let executor): + case .failTask(let error, let scheduler, let executor): scheduler?.cancelRequest(self) executor?.cancelRequest(self) self.failTask0(error) @@ -327,16 +411,30 @@ final class RequestBag { break } } -} -extension RequestBag: HTTPSchedulableRequest { - var poolKey: ConnectionPool.Key { - ConnectionPool.Key(self.request) + func deadlineExceeded0() { + let action = self.loopBoundState.value.state.deadlineExceeded() + + switch action { + case .cancelScheduler(let scheduler): + scheduler?.cancelRequest(self) + case .fail(let failAction): + self.executeFailAction0(failAction) + } } - var tlsConfiguration: TLSConfiguration? { - self.request.tlsConfiguration + func deadlineExceeded() { + if self.task.eventLoop.inEventLoop { + self.deadlineExceeded0() + } else { + self.task.eventLoop.execute { + self.deadlineExceeded0() + } + } } +} + +extension RequestBag: HTTPSchedulableRequest, HTTPClientTaskDelegate { func requestWasQueued(_ scheduler: HTTPRequestScheduler) { if self.task.eventLoop.inEventLoop { @@ -374,8 +472,8 @@ extension RequestBag: HTTPExecutableRequest { case .indifferent: return self.task.eventLoop case .delegate(let eventLoop), - .delegateAndChannel(on: let eventLoop), - .testOnly_exact(channelOn: let eventLoop, delegateOn: _): + .delegateAndChannel(on: let eventLoop), + .testOnly_exact(channelOn: let eventLoop, delegateOn: _): return eventLoop } } @@ -450,15 +548,3 @@ extension RequestBag: HTTPExecutableRequest { } } } - -extension RequestBag: HTTPClientTaskDelegate { - func cancel() { - if self.task.eventLoop.inEventLoop { - self.fail0(HTTPClientError.cancelled) - } else { - self.task.eventLoop.execute { - self.fail0(HTTPClientError.cancelled) - } - } - } -} diff --git a/Sources/AsyncHTTPClient/RequestValidation.swift b/Sources/AsyncHTTPClient/RequestValidation.swift index e23c35423..f338e06a9 100644 --- a/Sources/AsyncHTTPClient/RequestValidation.swift +++ b/Sources/AsyncHTTPClient/RequestValidation.swift @@ -21,6 +21,7 @@ extension HTTPHeaders { bodyLength: RequestBodyLength ) throws -> RequestFramingMetadata { try self.validateFieldNames() + try self.validateFieldValues() if case .TRACE = method { switch bodyLength { @@ -49,23 +50,23 @@ extension HTTPHeaders { let satisfy = name.utf8.allSatisfy { char -> Bool in switch char { case UInt8(ascii: "a")...UInt8(ascii: "z"), - UInt8(ascii: "A")...UInt8(ascii: "Z"), - UInt8(ascii: "0")...UInt8(ascii: "9"), - UInt8(ascii: "!"), - UInt8(ascii: "#"), - UInt8(ascii: "$"), - UInt8(ascii: "%"), - UInt8(ascii: "&"), - UInt8(ascii: "'"), - UInt8(ascii: "*"), - UInt8(ascii: "+"), - UInt8(ascii: "-"), - UInt8(ascii: "."), - UInt8(ascii: "^"), - UInt8(ascii: "_"), - UInt8(ascii: "`"), - UInt8(ascii: "|"), - UInt8(ascii: "~"): + UInt8(ascii: "A")...UInt8(ascii: "Z"), + UInt8(ascii: "0")...UInt8(ascii: "9"), + UInt8(ascii: "!"), + UInt8(ascii: "#"), + UInt8(ascii: "$"), + UInt8(ascii: "%"), + UInt8(ascii: "&"), + UInt8(ascii: "'"), + UInt8(ascii: "*"), + UInt8(ascii: "+"), + UInt8(ascii: "-"), + UInt8(ascii: "."), + UInt8(ascii: "^"), + UInt8(ascii: "_"), + UInt8(ascii: "`"), + UInt8(ascii: "|"), + UInt8(ascii: "~"): return true default: return false @@ -80,6 +81,56 @@ extension HTTPHeaders { } } + private func validateFieldValues() throws { + let invalidValues = self.compactMap { _, value -> String? in + let satisfy = value.utf8.allSatisfy { char -> Bool in + /// Validates a byte of a given header field value against the definition in RFC 9110. + /// + /// The spec in [RFC 9110](https://httpwg.org/specs/rfc9110.html#fields.values) defines the valid + /// characters as the following: + /// + /// ``` + /// field-value = *field-content + /// field-content = field-vchar + /// [ 1*( SP / HTAB / field-vchar ) field-vchar ] + /// field-vchar = VCHAR / obs-text + /// obs-text = %x80-FF + /// ``` + /// + /// Additionally, it makes the following note: + /// + /// "Field values containing CR, LF, or NUL characters are invalid and dangerous, due to the + /// varying ways that implementations might parse and interpret those characters; a recipient + /// of CR, LF, or NUL within a field value MUST either reject the message or replace each of + /// those characters with SP before further processing or forwarding of that message. Field + /// values containing other CTL characters are also invalid; however, recipients MAY retain + /// such characters for the sake of robustness when they appear within a safe context (e.g., + /// an application-specific quoted string that will not be processed by any downstream HTTP + /// parser)." + /// + /// As we cannot guarantee the context is safe, this code will reject all ASCII control characters + /// directly _except_ for HTAB, which is explicitly allowed. + switch char { + case UInt8(ascii: "\t"): + // HTAB, explicitly allowed. + return true + case 0...0x1f, 0x7F: + // ASCII control character, forbidden. + return false + default: + // Printable or non-ASCII, allowed. + return true + } + } + + return satisfy ? nil : value + } + + guard invalidValues.count == 0 else { + throw HTTPClientError.invalidHeaderFieldValues(invalidValues) + } + } + private mutating func setTransportFraming( method: HTTPMethod, bodyLength: RequestBodyLength @@ -115,13 +166,14 @@ extension HTTPHeaders { mutating func addHostIfNeeded(for url: DeconstructedURL) { // if no host header was set, let's use the url host guard !self.contains(name: "host"), - var host = url.connectionTarget.host + var host = url.connectionTarget.host else { return } // if the request uses a non-default port, we need to add it after the host if let port = url.connectionTarget.port, - port != url.scheme.defaultPort { + port != url.scheme.defaultPort + { host += ":\(port)" } self.add(name: "host", value: host) diff --git a/Sources/AsyncHTTPClient/SSLContextCache.swift b/Sources/AsyncHTTPClient/SSLContextCache.swift index 31ed106a0..599003e56 100644 --- a/Sources/AsyncHTTPClient/SSLContextCache.swift +++ b/Sources/AsyncHTTPClient/SSLContextCache.swift @@ -18,40 +18,50 @@ import NIOConcurrencyHelpers import NIOCore import NIOSSL -class SSLContextCache { - private let lock = Lock() +final class SSLContextCache { + private let lock = NIOLock() private var sslContextCache = LRUCache() private let offloadQueue = DispatchQueue(label: "io.github.swift-server.AsyncHTTPClient.SSLContextCache") } extension SSLContextCache { - func sslContext(tlsConfiguration: TLSConfiguration, - eventLoop: EventLoop, - logger: Logger) -> EventLoopFuture { + func sslContext( + tlsConfiguration: TLSConfiguration, + eventLoop: EventLoop, + logger: Logger + ) -> EventLoopFuture { let eqTLSConfiguration = BestEffortHashableTLSConfiguration(wrapping: tlsConfiguration) let sslContext = self.lock.withLock { self.sslContextCache.find(key: eqTLSConfiguration) } if let sslContext = sslContext { - logger.trace("found SSL context in cache", - metadata: ["ahc-tls-config": "\(tlsConfiguration)"]) + logger.trace( + "found SSL context in cache", + metadata: ["ahc-tls-config": "\(tlsConfiguration)"] + ) return eventLoop.makeSucceededFuture(sslContext) } - logger.trace("creating new SSL context", - metadata: ["ahc-tls-config": "\(tlsConfiguration)"]) + logger.trace( + "creating new SSL context", + metadata: ["ahc-tls-config": "\(tlsConfiguration)"] + ) let newSSLContext = self.offloadQueue.asyncWithFuture(eventLoop: eventLoop) { try NIOSSLContext(configuration: tlsConfiguration) } newSSLContext.whenSuccess { (newSSLContext: NIOSSLContext) -> Void in self.lock.withLock { () -> Void in - self.sslContextCache.append(key: eqTLSConfiguration, - value: newSSLContext) + self.sslContextCache.append( + key: eqTLSConfiguration, + value: newSSLContext + ) } } return newSSLContext } } + +extension SSLContextCache: @unchecked Sendable {} diff --git a/Sources/AsyncHTTPClient/Singleton.swift b/Sources/AsyncHTTPClient/Singleton.swift new file mode 100644 index 000000000..0ddf1bc40 --- /dev/null +++ b/Sources/AsyncHTTPClient/Singleton.swift @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +extension HTTPClient { + /// A globally shared, singleton ``HTTPClient``. + /// + /// The returned client uses the following settings: + /// - configuration is ``HTTPClient/Configuration/singletonConfiguration`` (matching the platform's default/prevalent browser as well as possible) + /// - `EventLoopGroup` is ``HTTPClient/defaultEventLoopGroup`` (matching the platform default) + /// - logging is disabled + public static var shared: HTTPClient { + globallySharedHTTPClient + } +} + +private let globallySharedHTTPClient: HTTPClient = { + let httpClient = HTTPClient( + eventLoopGroup: HTTPClient.defaultEventLoopGroup, + configuration: .singletonConfiguration, + backgroundActivityLogger: HTTPClient.loggingDisabled, + canBeShutDown: false + ) + return httpClient +}() diff --git a/Sources/AsyncHTTPClient/StringConvertibleInstances.swift b/Sources/AsyncHTTPClient/StringConvertibleInstances.swift index f75fb0d87..61d4b067a 100644 --- a/Sources/AsyncHTTPClient/StringConvertibleInstances.swift +++ b/Sources/AsyncHTTPClient/StringConvertibleInstances.swift @@ -14,6 +14,6 @@ extension HTTPClient.EventLoopPreference: CustomStringConvertible { public var description: String { - return "\(self.preference)" + "\(self.preference)" } } diff --git a/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift b/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift new file mode 100644 index 000000000..071e93d36 --- /dev/null +++ b/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift @@ -0,0 +1,49 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// swift-format-ignore +// Note: Whitespace changes are used to workaround compiler bug +// https://github.com/swiftlang/swift/issues/79285 + +@inlinable +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func asyncDo( + isolation: isolated (any Actor)? = #isolation, + _ body: () async throws -> sending R, + finally: sending @escaping ((any Error)?) async throws -> Void +) async throws -> sending R { + let result: R + do { + result = try await body() + } catch { + // `body` failed, we need to invoke `finally` with the `error`. + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(error) + }.value + throw error + } + + // `body` succeeded, we need to invoke `finally` with `nil` (no error). + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(nil) + }.value + return result +} diff --git a/Sources/AsyncHTTPClient/TracingSupport.swift b/Sources/AsyncHTTPClient/TracingSupport.swift new file mode 100644 index 000000000..feb564ffb --- /dev/null +++ b/Sources/AsyncHTTPClient/TracingSupport.swift @@ -0,0 +1,53 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOHTTP1 +import NIOSSL +import Tracing + +// MARK: - Centralized span attribute handling + +@usableFromInline +struct TracingSupport { + @inlinable + static func handleResponseStatusCode( + _ span: Span, + _ status: HTTPResponseStatus, + keys: HTTPClient.TracingConfiguration.AttributeKeys + ) { + if status.code >= 400 { + span.setStatus(.init(code: .error)) + } + span.attributes[keys.responseStatusCode] = SpanAttribute.int64(Int64(status.code)) + } +} + +// MARK: - HTTPHeadersInjector + +struct HTTPHeadersInjector: Injector, @unchecked Sendable { + static let shared: HTTPHeadersInjector = HTTPHeadersInjector() + + private init() {} + + func inject(_ value: String, forKey name: String, into headers: inout HTTPHeaders) { + headers.add(name: name, value: value) + } +} + +// MARK: - Errors + +internal struct HTTPRequestCancellationError: Error {} diff --git a/Sources/AsyncHTTPClient/Utils.swift b/Sources/AsyncHTTPClient/Utils.swift index f4154df3d..985755143 100644 --- a/Sources/AsyncHTTPClient/Utils.swift +++ b/Sources/AsyncHTTPClient/Utils.swift @@ -14,21 +14,26 @@ import NIOCore -public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate { +/// An ``HTTPClientResponseDelegate`` that wraps a callback. +/// +/// ``HTTPClientCopyingDelegate`` discards most parts of a HTTP response, but streams the body +/// to the `chunkHandler` provided on ``init(chunkHandler:)``. This is mostly useful for testing. +public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate, Sendable { public typealias Response = Void - let chunkHandler: (ByteBuffer) -> EventLoopFuture + let chunkHandler: @Sendable (ByteBuffer) -> EventLoopFuture - public init(chunkHandler: @escaping (ByteBuffer) -> EventLoopFuture) { + @preconcurrency + public init(chunkHandler: @Sendable @escaping (ByteBuffer) -> EventLoopFuture) { self.chunkHandler = chunkHandler } public func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - return self.chunkHandler(buffer) + self.chunkHandler(buffer) } public func didFinishRequest(task: HTTPClient.Task) throws { - return () + () } } @@ -39,7 +44,12 @@ public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate { /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. @inlinable internal func debugOnly(_ body: () -> Void) { - assert({ body(); return true }()) + assert( + { + body() + return true + }() + ) } extension BidirectionalCollection where Element: Equatable { @@ -56,8 +66,8 @@ extension BidirectionalCollection where Element: Equatable { guard self[ourIdx] == suffix[suffixIdx] else { return false } } guard suffixIdx == suffix.startIndex else { - return false // Exhausted self, but 'suffix' has elements remaining. + return false // Exhausted self, but 'suffix' has elements remaining. } - return true // Exhausted 'other' without finding a mismatch. + return true // Exhausted 'other' without finding a mismatch. } } diff --git a/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c b/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c index 2a09d04c9..6342da89f 100644 --- a/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c +++ b/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c @@ -15,7 +15,6 @@ #if __APPLE__ #include #elif __linux__ - #define _GNU_SOURCE #include #endif @@ -32,7 +31,11 @@ bool swiftahc_cshims_strptime(const char * string, const char * format, struct t bool swiftahc_cshims_strptime_l(const char * string, const char * format, struct tm * result, void * locale) { // The pointer cast is fine as long we make sure it really points to a locale_t. +#if defined(__musl__) || defined(__ANDROID__) + const char * firstNonProcessed = strptime(string, format, result); +#else const char * firstNonProcessed = strptime_l(string, format, result, (locale_t)locale); +#endif if (firstNonProcessed) { return *firstNonProcessed == 0; } diff --git a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift deleted file mode 100644 index 6a8d923c7..000000000 --- a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// AsyncAwaitEndToEndTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension AsyncAwaitEndToEndTests { - static var allTests: [(String, (AsyncAwaitEndToEndTests) -> () throws -> Void)] { - return [ - ("testSimpleGet", testSimpleGet), - ("testSimplePost", testSimplePost), - ("testPostWithByteBuffer", testPostWithByteBuffer), - ("testPostWithSequenceOfUInt8", testPostWithSequenceOfUInt8), - ("testPostWithCollectionOfUInt8", testPostWithCollectionOfUInt8), - ("testPostWithRandomAccessCollectionOfUInt8", testPostWithRandomAccessCollectionOfUInt8), - ("testPostWithAsyncSequenceOfByteBuffers", testPostWithAsyncSequenceOfByteBuffers), - ("testPostWithAsyncSequenceOfUInt8", testPostWithAsyncSequenceOfUInt8), - ("testPostWithFragmentedAsyncSequenceOfByteBuffers", testPostWithFragmentedAsyncSequenceOfByteBuffers), - ("testPostWithFragmentedAsyncSequenceOfLargeByteBuffers", testPostWithFragmentedAsyncSequenceOfLargeByteBuffers), - ("testCanceling", testCanceling), - ("testDeadline", testDeadline), - ("testImmediateDeadline", testImmediateDeadline), - ("testInvalidURL", testInvalidURL), - ("testRedirectChangesHostHeader", testRedirectChangesHostHeader), - ("testShutdown", testShutdown), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift index 2cd056225..56a08b852 100644 --- a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift +++ b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift @@ -12,14 +12,18 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore +import NIOFoundationCompat +import NIOHTTP1 import NIOPosix +import NIOSSL import XCTest +@testable import AsyncHTTPClient + private func makeDefaultHTTPClient( - eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .createNew + eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .singleton ) -> HTTPClient { var config = HTTPClient.Configuration() config.tlsConfiguration = .clientDefault @@ -32,10 +36,30 @@ private func makeDefaultHTTPClient( ) } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class AsyncAwaitEndToEndTests: XCTestCase { + var clientGroup: EventLoopGroup! + var serverGroup: EventLoopGroup! + + override func setUp() { + XCTAssertNil(self.clientGroup) + XCTAssertNil(self.serverGroup) + + self.clientGroup = getDefaultEventLoopGroup(numberOfThreads: 1) + self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + } + + override func tearDown() { + XCTAssertNotNil(self.clientGroup) + XCTAssertNoThrow(try self.clientGroup.syncShutdownGracefully()) + self.clientGroup = nil + + XCTAssertNotNil(self.serverGroup) + XCTAssertNoThrow(try self.serverGroup.syncShutdownGracefully()) + self.serverGroup = nil + } + func testSimpleGet() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -44,21 +68,22 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } + XCTAssertEqual(response.url?.absoluteString, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) } - #endif } func testSimplePost() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -67,21 +92,22 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } + XCTAssertEqual(response.url?.absoluteString, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) } - #endif } func testPostWithByteBuffer() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -92,21 +118,22 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.method = .POST request.body = .bytes(ByteBuffer(string: "1234")) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], ["4"]) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif } func testPostWithSequenceOfUInt8() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -115,23 +142,24 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") request.method = .POST - request.body = .bytes(AnySequence("1234".utf8), length: .unknown) + request.body = .bytes(AnySendableSequence("1234".utf8), length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif } func testPostWithCollectionOfUInt8() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -140,23 +168,24 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") request.method = .POST - request.body = .bytes(AnyCollection("1234".utf8), length: .unknown) + request.body = .bytes(AnySendableCollection("1234".utf8), length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif } func testPostWithRandomAccessCollectionOfUInt8() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -167,21 +196,82 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.method = .POST request.body = .bytes(ByteBuffer(string: "1234").readableBytesView) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], ["4"]) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif + } + + struct AsyncSequenceByteBufferGenerator: AsyncSequence, Sendable, AsyncIteratorProtocol { + typealias Element = ByteBuffer + + let chunkSize: Int + let totalChunks: Int + let buffer: ByteBuffer + var chunksGenerated: Int = 0 + + init(chunkSize: Int, totalChunks: Int) { + self.chunkSize = chunkSize + self.totalChunks = totalChunks + self.buffer = ByteBuffer(repeating: 1, count: self.chunkSize) + } + + mutating func next() async throws -> ByteBuffer? { + guard self.chunksGenerated < self.totalChunks else { return nil } + + self.chunksGenerated += 1 + return self.buffer + } + + func makeAsyncIterator() -> AsyncSequenceByteBufferGenerator { + self + } + } + + func testEchoStreamThatHas3GBInTotal() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let bin = HTTPBin(.http1_1()) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let client: HTTPClient = makeDefaultHTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + + var request = HTTPClientRequest(url: "http://localhost:\(bin.port)/") + request.method = .POST + + let sequence = AsyncSequenceByteBufferGenerator( + chunkSize: 4_194_304, // 4MB chunk + totalChunks: 768 // Total = 3GB + ) + request.body = .stream(sequence, length: .unknown) + + let response: HTTPClientResponse = try await client.execute( + request, + deadline: .now() + .seconds(30), + logger: logger + ) + XCTAssertEqual(response.headers["content-length"], []) + + var receivedBytes: Int64 = 0 + for try await part in response.body { + receivedBytes += Int64(part.readableBytes) + } + XCTAssertEqual(receivedBytes, 3_221_225_472) // 3GB } func testPostWithAsyncSequenceOfByteBuffers() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -190,27 +280,31 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") request.method = .POST - request.body = .stream([ - ByteBuffer(string: "1"), - ByteBuffer(string: "2"), - ByteBuffer(string: "34"), - ].asAsyncSequence(), length: .unknown) + request.body = .stream( + [ + ByteBuffer(string: "1"), + ByteBuffer(string: "2"), + ByteBuffer(string: "34"), + ].async, + length: .unknown + ) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif } func testPostWithAsyncSequenceOfUInt8() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -219,23 +313,24 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") request.method = .POST - request.body = .stream("1234".utf8.asAsyncSequence(), length: .unknown) + request.body = .stream("1234".utf8.async, length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } - #endif } func testPostWithFragmentedAsyncSequenceOfByteBuffers() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -247,9 +342,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let streamWriter = AsyncSequenceWriter() request.body = .stream(streamWriter, length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) let fragments = [ @@ -260,24 +357,25 @@ final class AsyncAwaitEndToEndTests: XCTestCase { var bodyIterator = response.body.makeAsyncIterator() for expectedFragment in fragments { streamWriter.write(expectedFragment) - guard let actualFragment = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let actualFragment = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(expectedFragment, actualFragment) } streamWriter.end() - guard let lastResult = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let lastResult = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(lastResult, nil) } - #endif } func testPostWithFragmentedAsyncSequenceOfLargeByteBuffers() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -289,9 +387,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let streamWriter = AsyncSequenceWriter() request.body = .stream(streamWriter, length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) let fragments = [ @@ -303,24 +403,25 @@ final class AsyncAwaitEndToEndTests: XCTestCase { var bodyIterator = response.body.makeAsyncIterator() for expectedFragment in fragments { streamWriter.write(expectedFragment) - guard let actualFragment = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let actualFragment = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(expectedFragment, actualFragment) } streamWriter.end() - guard let lastResult = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let lastResult = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(lastResult, nil) } - #endif } func testCanceling() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest(timeout: 5) { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -337,15 +438,40 @@ final class AsyncAwaitEndToEndTests: XCTestCase { } task.cancel() await XCTAssertThrowsError(try await task.value) { error in - XCTAssertEqual(error as? HTTPClientError, .cancelled) + XCTAssertTrue(error is CancellationError, "unexpected error \(error)") + } + } + } + + func testCancelingResponseBody() { + XCTAsyncTest(timeout: 5) { + let bin = HTTPBin(.http2(compress: false)) { _ in + HTTPEchoHandler() } + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/handler") + request.method = .POST + let streamWriter = AsyncSequenceWriter() + request.body = .stream(streamWriter, length: .unknown) + let response = try await client.execute(request, deadline: .now() + .seconds(2), logger: logger) + streamWriter.write(.init(bytes: [1])) + let task = Task { + try await response.body.collect(upTo: 1024 * 1024) + } + task.cancel() + + await XCTAssertThrowsError(try await task.value) { error in + XCTAssertTrue(error is CancellationError, "unexpected error \(error)") + } + + streamWriter.end() } - #endif } func testDeadline() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest(timeout: 5) { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -361,16 +487,18 @@ final class AsyncAwaitEndToEndTests: XCTestCase { guard let error = error as? HTTPClientError else { return XCTFail("unexpected error \(error)") } - // a race between deadline and connect timer can result in either error - XCTAssertTrue([.deadlineExceeded, .connectTimeout].contains(error)) + // a race between deadline and connect timer can result in either error. + // If closing happens really fast we might shutdown the pipeline before we fail the request. + // If the pipeline is closed we may receive a `.remoteConnectionClosed`. + XCTAssertTrue( + [.deadlineExceeded, .connectTimeout, .remoteConnectionClosed].contains(error), + "unexpected error \(error)" + ) } } - #endif } func testImmediateDeadline() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest(timeout: 5) { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -386,32 +514,225 @@ final class AsyncAwaitEndToEndTests: XCTestCase { guard let error = error as? HTTPClientError else { return XCTFail("unexpected error \(error)") } - // a race between deadline and connect timer can result in either error - XCTAssertTrue([.deadlineExceeded, .connectTimeout].contains(error)) + // a race between deadline and connect timer can result in either error. + // If closing happens really fast we might shutdown the pipeline before we fail the request. + // If the pipeline is closed we may receive a `.remoteConnectionClosed`. + XCTAssertTrue( + [.deadlineExceeded, .connectTimeout, .remoteConnectionClosed].contains(error), + "unexpected error \(error)" + ) + } + } + } + + func testConnectTimeout() { + let serverGroup = self.serverGroup! + let clientGroup = self.clientGroup! + XCTAsyncTest(timeout: 60) { + #if os(Linux) + // 198.51.100.254 is reserved for documentation only and therefore should not accept any TCP connection + let url = "http://198.51.100.254/get" + #else + // on macOS we can use the TCP backlog behaviour when the queue is full to simulate a non reachable server. + // this makes this test a bit more stable if `198.51.100.254` actually responds to connection attempt. + // The backlog behaviour on Linux can not be used to simulate a non-reachable server. + // Linux sends a `SYN/ACK` back even if the `backlog` queue is full as it has two queues. + // The second queue is not limit by `ChannelOptions.backlog` but by `/proc/sys/net/ipv4/tcp_max_syn_backlog`. + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let serverChannel = try await ServerBootstrap(group: serverGroup) + .serverChannelOption(ChannelOptions.backlog, value: 1) + .serverChannelOption(ChannelOptions.autoRead, value: false) + .bind(host: "127.0.0.1", port: 0) + .get() + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + let port = serverChannel.localAddress!.port! + let firstClientChannel = try await ClientBootstrap(group: serverGroup) + .connect(host: "127.0.0.1", port: port) + .get() + defer { + XCTAssertNoThrow(try firstClientChannel.close().wait()) } + let url = "http://localhost:\(port)/get" + #endif + + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150))) + ) + + defer { + XCTAssertNoThrow(try httpClient.syncShutdown()) + } + + let request = HTTPClientRequest(url: url) + let start = NIODeadline.now() + await XCTAssertThrowsError(try await httpClient.execute(request, deadline: .now() + .seconds(30))) { + XCTAssertEqualTypeAndValue($0, HTTPClientError.connectTimeout) + let end = NIODeadline.now() + let duration = end - start + + // We give ourselves 10x slack in order to be confident that even on slow machines this assertion passes. + // It's 30x smaller than our other timeout though. + XCTAssertLessThan(duration, .seconds(1)) + } + } + } + + func testSelfSignedCertificateIsRejectedWithCorrectErrorIfRequestDeadlineIsExceeded() { + XCTAsyncTest(timeout: 5) { + /// key + cert was created with the follwing command: + /// openssl req -x509 -newkey rsa:4096 -keyout self_signed_key.pem -out self_signed_cert.pem -sha256 -days 99999 -nodes -subj '/CN=localhost' + let certPath = Bundle.module.path(forResource: "self_signed_cert", ofType: "pem")! + let keyPath = Bundle.module.path(forResource: "self_signed_key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) + let configuration = TLSConfiguration.makeServerConfiguration( + certificateChain: try NIOSSLCertificate.fromPEMFile(certPath).map { .certificate($0) }, + privateKey: .privateKey(key) + ) + let sslContext = try NIOSSLContext(configuration: configuration) + let serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try serverGroup.syncShutdownGracefully()) } + let server = ServerBootstrap(group: serverGroup) + .childChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + } + let serverChannel = try await server.bind(host: "localhost", port: 0).get() + defer { XCTAssertNoThrow(try serverChannel.close().wait()) } + let port = serverChannel.localAddress!.port! + + let config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + + let localClient = HTTPClient(eventLoopGroupProvider: .singleton, configuration: config) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + let request = HTTPClientRequest(url: "https://localhost:\(port)") + await XCTAssertThrowsError(try await localClient.execute(request, deadline: .now() + .seconds(2))) { + error in + #if canImport(Network) + guard let nwTLSError = error as? HTTPClient.NWTLSError else { + XCTFail("could not cast \(error) of type \(type(of: error)) to \(HTTPClient.NWTLSError.self)") + return + } + XCTAssertEqual(nwTLSError.status, errSSLBadCert, "unexpected tls error: \(nwTLSError)") + #else + guard let sslError = error as? NIOSSLError, + case .handshakeFailed(.sslError) = sslError + else { + XCTFail("unexpected error \(error)") + return + } + #endif + } + } + } + + func testDnsOverride() { + XCTAsyncTest(timeout: 5) { + // key + cert was created with the following code (depends on swift-certificates) + // ``` + // import X509 + // import CryptoKit + // import Foundation + // + // let privateKey = P384.Signing.PrivateKey() + // let name = try DistinguishedName { + // OrganizationName("Self Signed") + // CommonName("localhost") + // } + // let certificate = try Certificate( + // version: .v3, + // serialNumber: .init(), + // publicKey: .init(privateKey.publicKey), + // notValidBefore: Date(), + // notValidAfter: Date().advanced(by: 365 * 24 * 3600), + // issuer: name, + // subject: name, + // signatureAlgorithm: .ecdsaWithSHA384, + // extensions: try .init { + // SubjectAlternativeNames([.dnsName("example.com")]) + // try ExtendedKeyUsage([.serverAuth]) + // }, + // issuerPrivateKey: .init(privateKey) + // ) + // ``` + let certPath = Bundle.module.path(forResource: "example.com.cert", ofType: "pem")! + let keyPath = Bundle.module.path(forResource: "example.com.private-key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) + let localhostCert = try NIOSSLCertificate.fromPEMFile(certPath) + let configuration = TLSConfiguration.makeServerConfiguration( + certificateChain: localhostCert.map { .certificate($0) }, + privateKey: .privateKey(key) + ) + let bin = HTTPBin(.http2(tlsConfiguration: configuration)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + var config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + var tlsConfig = TLSConfiguration.makeClientConfiguration() + + tlsConfig.trustRoots = .certificates(localhostCert) + config.tlsConfiguration = tlsConfig + // this is the actual configuration under test + config.dnsOverride = ["example.com": "localhost"] + + let localClient = HTTPClient(eventLoopGroupProvider: .singleton, configuration: config) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + let request = HTTPClientRequest(url: "https://example.com:\(bin.port)/echohostheader") + let response = await XCTAssertNoThrowWithResult( + try await localClient.execute(request, deadline: .now() + .seconds(2)) + ) + XCTAssertEqual(response?.status, .ok) + XCTAssertEqual(response?.version, .http2) + var body = try await response?.body.collect(upTo: 1024) + let readableBytes = body?.readableBytes ?? 0 + let responseInfo = try body?.readJSONDecodable(RequestInfo.self, length: readableBytes) + XCTAssertEqual(responseInfo?.data, "example.com\(bin.port == 443 ? "" : ":\(bin.port)")") } - #endif } func testInvalidURL() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest(timeout: 5) { let client = makeDefaultHTTPClient() defer { XCTAssertNoThrow(try client.syncShutdown()) } let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) - let request = HTTPClientRequest(url: "") // invalid URL + let request = HTTPClientRequest(url: "") // invalid URL - await XCTAssertThrowsError(try await client.execute(request, deadline: .now() + .seconds(2), logger: logger)) { + await XCTAssertThrowsError( + try await client.execute(request, deadline: .now() + .seconds(2), logger: logger) + ) { XCTAssertEqual($0 as? HTTPClientError, .invalidURL) } } - #endif + } + + func testInsanelyHighConcurrentHTTP1ConnectionLimitDoesNotCrash() async throws { + let bin = HTTPBin(.http1_1(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + var httpClientConfig = HTTPClient.Configuration() + httpClientConfig.connectionPool = .init( + idleTimeout: .hours(1), + concurrentHTTP1ConnectionsPerHostSoftLimit: Int.max + ) + httpClientConfig.timeout = .init(connect: .seconds(10), read: .seconds(100), write: .seconds(100)) + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: httpClientConfig) + defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } + + let request = HTTPClientRequest(url: "http://localhost:\(bin.port)") + _ = try await httpClient.execute(request, deadline: .now() + .seconds(2)) } func testRedirectChangesHostHeader() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -419,28 +740,35 @@ final class AsyncAwaitEndToEndTests: XCTestCase { defer { XCTAssertNoThrow(try client.syncShutdown()) } let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://127.0.0.1:\(bin.port)/redirect/target") - request.headers.replaceOrAdd(name: "X-Target-Redirect-URL", value: "https://localhost:\(bin.port)/echohostheader") + let redirectURL = "https://localhost:\(bin.port)/echohostheader" + request.headers.replaceOrAdd( + name: "X-Target-Redirect-URL", + value: redirectURL + ) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { + return + } + guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect(upTo: 1024)) else { return } - guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect()) else { return } var maybeRequestInfo: RequestInfo? XCTAssertNoThrow(maybeRequestInfo = try JSONDecoder().decode(RequestInfo.self, from: body)) guard let requestInfo = maybeRequestInfo else { return } + XCTAssertEqual(response.url?.absoluteString, redirectURL) + XCTAssertEqual(response.history.map(\.request.url), [request.url, redirectURL]) XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) XCTAssertEqual(requestInfo.data, "localhost:\(bin.port)") } - #endif } func testShutdown() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let client = makeDefaultHTTPClient() try await client.shutdown() @@ -448,17 +776,289 @@ final class AsyncAwaitEndToEndTests: XCTestCase { XCTAssertEqualTypeAndValue(error, HTTPClientError.alreadyShutdown) } } - #endif } -} -#if compiler(>=5.5.2) && canImport(_Concurrency) -extension AsyncSequence where Element == ByteBuffer { - func collect() async rethrows -> ByteBuffer { - try await self.reduce(into: ByteBuffer()) { accumulatingBuffer, nextBuffer in - var nextBuffer = nextBuffer - accumulatingBuffer.writeBuffer(&nextBuffer) + /// Regression test for https://github.com/swift-server/async-http-client/issues/612 + func testCancelingBodyDoesNotCrash() { + XCTAsyncTest { + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let bin = HTTPBin(.http2(compress: true)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let request = HTTPClientRequest(url: "https://127.0.0.1:\(bin.port)/mega-chunked") + let response = try await client.execute(request, deadline: .now() + .seconds(10)) + + await XCTAssertThrowsError(try await response.body.collect(upTo: 100)) { error in + XCTAssert(error is NIOTooManyBytesError) + } + } + } + + func testAsyncSequenceReuse() { + XCTAsyncTest { + let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") + request.method = .POST + request.body = .stream( + [ + ByteBuffer(string: "1"), + ByteBuffer(string: "2"), + ByteBuffer(string: "34"), + ].async, + length: .unknown + ) + + guard + let response1 = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } + XCTAssertEqual(response1.headers["content-length"], []) + guard + let body = await XCTAssertNoThrowWithResult( + try await response1.body.collect(upTo: 1024) + ) + else { return } + XCTAssertEqual(body, ByteBuffer(string: "1234")) + + guard + let response2 = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } + XCTAssertEqual(response2.headers["content-length"], []) + guard + let body = await XCTAssertNoThrowWithResult( + try await response2.body.collect(upTo: 1024) + ) + else { return } + XCTAssertEqual(body, ByteBuffer(string: "1234")) + } + } + + func testRejectsInvalidCharactersInHeaderFieldNames_http1() { + self._rejectsInvalidCharactersInHeaderFieldNames(mode: .http1_1(ssl: true)) + } + + func testRejectsInvalidCharactersInHeaderFieldNames_http2() { + self._rejectsInvalidCharactersInHeaderFieldNames(mode: .http2(compress: false)) + } + + private func _rejectsInvalidCharactersInHeaderFieldNames(mode: HTTPBin.Mode) { + XCTAsyncTest { + let bin = HTTPBin(mode) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + + // The spec in [RFC 9110](https://httpwg.org/specs/rfc9110.html#fields.values) defines the valid + // characters as the following: + // + // ``` + // field-name = token + // + // token = 1*tchar + // + // tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" + // / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" + // / DIGIT / ALPHA + // ; any VCHAR, except delimiters + let weirdAllowedFieldName = "!#$%&'*+-.^_`|~0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + request.headers.add(name: weirdAllowedFieldName, value: "present") + + // This should work fine. + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { + return + } + + XCTAssertEqual(response.status, .ok) + + // Now, let's confirm all other bytes are rejected. We want to stay within the ASCII space as the HTTPHeaders type will forbid anything else. + for byte in UInt8(0)...UInt8(127) { + // Skip bytes that we already believe are allowed. + if weirdAllowedFieldName.utf8.contains(byte) { + continue + } + let forbiddenFieldName = weirdAllowedFieldName + String(decoding: [byte], as: UTF8.self) + + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + request.headers.add(name: forbiddenFieldName, value: "present") + + await XCTAssertThrowsError( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) { error in + XCTAssertEqual(error as? HTTPClientError, .invalidHeaderFieldNames([forbiddenFieldName])) + } + } + } + } + + func testRejectsInvalidCharactersInHeaderFieldValues_http1() { + self._rejectsInvalidCharactersInHeaderFieldValues(mode: .http1_1(ssl: true)) + } + + func testRejectsInvalidCharactersInHeaderFieldValues_http2() { + self._rejectsInvalidCharactersInHeaderFieldValues(mode: .http2(compress: false)) + } + + private func _rejectsInvalidCharactersInHeaderFieldValues(mode: HTTPBin.Mode) { + XCTAsyncTest { + let bin = HTTPBin(mode) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + + // We reject all ASCII control characters except HTAB and tolerate everything else. + let weirdAllowedFieldValue = + "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" + + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + request.headers.add(name: "Weird-Value", value: weirdAllowedFieldValue) + + // This should work fine. + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { + return + } + + XCTAssertEqual(response.status, .ok) + + // Now, let's confirm all other bytes in the ASCII range ar rejected + for byte in UInt8(0)...UInt8(127) { + // Skip bytes that we already believe are allowed. + if weirdAllowedFieldValue.utf8.contains(byte) { + continue + } + let forbiddenFieldValue = weirdAllowedFieldValue + String(decoding: [byte], as: UTF8.self) + + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + request.headers.add(name: "Weird-Value", value: forbiddenFieldValue) + + await XCTAssertThrowsError( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) { error in + XCTAssertEqual(error as? HTTPClientError, .invalidHeaderFieldValues([forbiddenFieldValue])) + } + } + + // All the bytes outside the ASCII range are fine though. + for byte in UInt8(128)...UInt8(255) { + let evenWeirderAllowedValue = weirdAllowedFieldValue + String(decoding: [byte], as: UTF8.self) + + var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") + request.headers.add(name: "Weird-Value", value: evenWeirderAllowedValue) + + // This should work fine. + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { + return + } + + XCTAssertEqual(response.status, .ok) + } + } + } + + func testUsingGetMethodInsteadOfWait() { + XCTAsyncTest { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let request = try HTTPClient.Request(url: "https://localhost:\(bin.port)/get") + + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request: request).get() + ) + else { + return + } + + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.version, .http2) + } + } + + func testSimpleContentLengthErrorNoBody() { + XCTAsyncTest { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/content-length-without-body") + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } + await XCTAssertThrowsError( + try await response.body.collect(upTo: 3) + ) { + XCTAssertEqualTypeAndValue($0, NIOTooManyBytesError(maxBytes: 3)) + } } } } -#endif + +struct AnySendableSequence: @unchecked Sendable { + private let wrapped: AnySequence + init( + _ sequence: WrappedSequence + ) where WrappedSequence.Element == Element { + self.wrapped = .init(sequence) + } +} + +extension AnySendableSequence: Sequence { + func makeIterator() -> AnySequence.Iterator { + self.wrapped.makeIterator() + } +} + +struct AnySendableCollection: @unchecked Sendable { + private let wrapped: AnyCollection + init( + _ collection: WrappedCollection + ) where WrappedCollection.Element == Element { + self.wrapped = .init(collection) + } +} + +extension AnySendableCollection: Collection { + var startIndex: AnyCollection.Index { + self.wrapped.startIndex + } + + var endIndex: AnyCollection.Index { + self.wrapped.endIndex + } + + func index(after i: AnyIndex) -> AnyIndex { + self.wrapped.index(after: i) + } + + subscript(position: AnyCollection.Index) -> Element { + self.wrapped[position] + } +} diff --git a/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift b/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift index 312008959..5e063be81 100644 --- a/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift +++ b/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift @@ -12,12 +12,12 @@ // //===----------------------------------------------------------------------===// -#if compiler(>=5.5.2) && canImport(_Concurrency) import NIOConcurrencyHelpers import NIOCore +/// ``AsyncSequenceWriter`` is `Sendable` because its state is protected by a Lock @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -class AsyncSequenceWriter: AsyncSequence { +final class AsyncSequenceWriter: AsyncSequence, @unchecked Sendable { typealias AsyncIterator = Iterator struct Iterator: AsyncIteratorProtocol { @@ -33,7 +33,7 @@ class AsyncSequenceWriter: AsyncSequence { } func makeAsyncIterator() -> Iterator { - return Iterator(self) + Iterator(self) } private enum State { @@ -43,12 +43,11 @@ class AsyncSequenceWriter: AsyncSequence { case failed(Error, CheckedContinuation?) } - private var _state = State.buffering(.init(), nil) - private let lock = Lock() + private let state = NIOLockedValueBox(.buffering([], nil)) public var hasDemand: Bool { - self.lock.withLock { - switch self._state { + self.state.withLockedValue { state in + switch state { case .failed, .finished, .buffering: return false case .waiting: @@ -59,65 +58,132 @@ class AsyncSequenceWriter: AsyncSequence { /// Wait until a downstream consumer has issued more demand by calling `next`. public func demand() async { - self.lock.lock() + let shouldBuffer = self.state.withLockedValue { state in + switch state { + case .buffering(_, .none): + return true + case .waiting: + return false + case .buffering(_, .some), .failed(_, .some): + preconditionFailure("Already waiting for demand. Invalid state: \(state)") + case .finished, .failed: + preconditionFailure("Invalid state: \(state)") + } + } - switch self._state { - case .buffering(let buffer, .none): + if shouldBuffer { await withCheckedContinuation { (continuation: CheckedContinuation) in - self._state = .buffering(buffer, continuation) - self.lock.unlock() + let shouldResumeContinuation = self.state.withLockedValue { state in + switch state { + case .buffering(let buffer, .none): + state = .buffering(buffer, continuation) + return false + case .waiting: + return true + case .buffering(_, .some), .failed(_, .some): + preconditionFailure("Already waiting for demand. Invalid state: \(state)") + case .finished, .failed: + preconditionFailure("Invalid state: \(state)") + } + } + + if shouldResumeContinuation { + continuation.resume() + } } - - case .waiting: - self.lock.unlock() - return - - case .buffering(_, .some), .failed(_, .some): - let state = self._state - self.lock.unlock() - preconditionFailure("Already waiting for demand. Invalid state: \(state)") - - case .finished, .failed: - let state = self._state - self.lock.unlock() - preconditionFailure("Invalid state: \(state)") } } + private enum NextAction { + /// Resume the continuation if present, and return the result if present. + case resumeAndReturn(CheckedContinuation?, Result?) + /// Suspend the current task and wait for the next value. + case suspend + } + private func next() async throws -> Element? { - self.lock.lock() - switch self._state { - case .buffering(let buffer, let demandContinuation) where buffer.isEmpty: - return try await withCheckedThrowingContinuation { continuation in - self._state = .waiting(continuation) - self.lock.unlock() - demandContinuation?.resume(returning: ()) - } + let action: NextAction = self.state.withLockedValue { state in + switch state { + case .buffering(var buffer, let demandContinuation): + if buffer.isEmpty { + return .suspend + } else { + let first = buffer.removeFirst() + if first != nil { + state = .buffering(buffer, demandContinuation) + } else { + state = .finished + } + return .resumeAndReturn(nil, .success(first)) + } + + case .failed(let error, let demandContinuation): + state = .finished + return .resumeAndReturn(demandContinuation, .failure(error)) + + case .finished: + return .resumeAndReturn(nil, .success(nil)) - case .buffering(var buffer, let demandContinuation): - let first = buffer.removeFirst() - if first != nil { - self._state = .buffering(buffer, demandContinuation) - } else { - self._state = .finished + case .waiting: + preconditionFailure( + "Expected that there is always only one concurrent call to next. Invalid state: \(state)" + ) } - self.lock.unlock() - return first + } - case .failed(let error, let demandContinuation): - self._state = .finished - self.lock.unlock() + switch action { + case .resumeAndReturn(let demandContinuation, let result): demandContinuation?.resume() - throw error - - case .finished: - self.lock.unlock() - return nil - - case .waiting: - let state = self._state - self.lock.unlock() - preconditionFailure("Expected that there is always only one concurrent call to next. Invalid state: \(state)") + return try result?.get() + + case .suspend: + // Holding the lock here *should* be safe but because of a bug in the runtime + // it isn't, so drop the lock, create the continuation and then try again. + // + // See https://github.com/swiftlang/swift/issues/85668 + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in + let action: NextAction = self.state.withLockedValue { state in + switch state { + case .buffering(var buffer, let demandContinuation): + if buffer.isEmpty { + state = .waiting(continuation) + return .resumeAndReturn(demandContinuation, nil) + } else { + let first = buffer.removeFirst() + if first != nil { + state = .buffering(buffer, demandContinuation) + } else { + state = .finished + } + return .resumeAndReturn(nil, .success(first)) + } + + case .failed(let error, let demandContinuation): + state = .finished + return .resumeAndReturn(demandContinuation, .failure(error)) + + case .finished: + return .resumeAndReturn(nil, .success(nil)) + + case .waiting: + preconditionFailure( + "Expected that there is always only one concurrent call to next. Invalid state: \(state)" + ) + } + } + + switch action { + case .resumeAndReturn(let demandContinuation, let result): + demandContinuation?.resume() + // Resume the continuation rather than returning th result. + if let result { + continuation.resume(with: result) + } + case .suspend: + preconditionFailure() // Not returned from the code above. + } + } } } @@ -135,19 +201,19 @@ class AsyncSequenceWriter: AsyncSequence { } private func writeBufferOrEnd(_ element: Element?) { - let writeAction = self.lock.withLock { () -> WriteAction in - switch self._state { + let writeAction = self.state.withLockedValue { state -> WriteAction in + switch state { case .buffering(var buffer, let continuation): buffer.append(element) - self._state = .buffering(buffer, continuation) + state = .buffering(buffer, continuation) return .none case .waiting(let continuation): - self._state = .buffering(.init(), nil) + state = .buffering(.init(), nil) return .succeedContinuation(continuation, element) case .finished, .failed: - preconditionFailure("Invalid state: \(self._state)") + preconditionFailure("Invalid state: \(state)") } } @@ -168,17 +234,17 @@ class AsyncSequenceWriter: AsyncSequence { /// Drops all buffered writes and emits an error on the waiting `next`. If there is no call to `next` /// waiting, will emit the error on the next call to `next`. public func fail(_ error: Error) { - let errorAction = self.lock.withLock { () -> ErrorAction in - switch self._state { + let errorAction = self.state.withLockedValue { state -> ErrorAction in + switch state { case .buffering(_, let demandContinuation): - self._state = .failed(error, demandContinuation) + state = .failed(error, demandContinuation) return .none case .failed, .finished: return .none case .waiting(let continuation): - self._state = .finished + state = .finished return .failContinuation(continuation, error) } } @@ -191,4 +257,3 @@ class AsyncSequenceWriter: AsyncSequence { } } } -#endif diff --git a/Tests/AsyncHTTPClientTests/ConnectionPoolSizeConfigValueIsRespectedTests.swift b/Tests/AsyncHTTPClientTests/ConnectionPoolSizeConfigValueIsRespectedTests.swift new file mode 100644 index 000000000..962791334 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/ConnectionPoolSizeConfigValueIsRespectedTests.swift @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class ConnectionPoolSizeConfigValueIsRespectedTests: XCTestCaseHTTPClientTestsBaseClass { + func testConnectionPoolSizeConfigValueIsRespected() { + let numberOfRequestsPerThread = 1000 + let numberOfParallelWorkers = 16 + let poolSize = 12 + + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 4) + defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } + + let configuration = HTTPClient.Configuration( + connectionPool: .init( + idleTimeout: .seconds(30), + concurrentHTTP1ConnectionsPerHostSoftLimit: poolSize + ) + ) + let client = HTTPClient(eventLoopGroupProvider: .shared(group), configuration: configuration) + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + let g = DispatchGroup() + for workerID in 0.. Void = { _ in }) throws { let part = try self.readOutbound(as: HTTPClientRequestPart.self) @@ -58,7 +59,7 @@ extension EmbeddedChannel { } struct HTTP1TestTools { - let connection: HTTP1Connection + let connection: HTTP1Connection.SendableView let connectionDelegate: MockConnectionDelegate let readEventHandler: ReadEventHitHandler let logger: Logger @@ -77,7 +78,7 @@ extension EmbeddedChannel { channel: self, connectionID: 1, delegate: connectionDelegate, - configuration: .init(), + decompression: .disabled, logger: logger ) @@ -86,8 +87,8 @@ extension EmbeddedChannel { let decoder = try self.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) let encoder = try self.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self) - let removeDecoderFuture = self.pipeline.removeHandler(decoder) - let removeEncoderFuture = self.pipeline.removeHandler(encoder) + let removeDecoderFuture = self.pipeline.syncOperations.removeHandler(decoder) + let removeEncoderFuture = self.pipeline.syncOperations.removeHandler(encoder) self.embeddedEventLoop.run() @@ -95,7 +96,7 @@ extension EmbeddedChannel { try removeEncoderFuture.wait() return .init( - connection: connection, + connection: connection.sendableView, connectionDelegate: connectionDelegate, readEventHandler: readEventHandler, logger: logger @@ -111,6 +112,6 @@ public struct HTTP1EmbeddedChannelError: Error, Hashable, CustomStringConvertibl } public var description: String { - return self.reason + self.reason } } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift deleted file mode 100644 index 86707520c..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests+XCTest.swift +++ /dev/null @@ -1,37 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP1ClientChannelHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP1ClientChannelHandlerTests { - static var allTests: [(String, (HTTP1ClientChannelHandlerTests) -> () throws -> Void)] { - return [ - ("testResponseBackpressure", testResponseBackpressure), - ("testWriteBackpressure", testWriteBackpressure), - ("testClientHandlerCancelsRequestIfWeWantToShutdown", testClientHandlerCancelsRequestIfWeWantToShutdown), - ("testIdleReadTimeout", testIdleReadTimeout), - ("testIdleReadTimeoutIsCanceledIfRequestIsCanceled", testIdleReadTimeoutIsCanceledIfRequestIsCanceled), - ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand), - ("testWriteHTTPHeadFails", testWriteHTTPHeadFails), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift index 4769d2c7e..0d871b7dc 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -12,13 +12,15 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging +import NIOConcurrencyHelpers import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class HTTP1ClientChannelHandlerTests: XCTestCase { func testResponseBackpressure() { let embedded = EmbeddedChannel() @@ -32,27 +34,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -113,22 +123,30 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 100) { writer in - testWriter.start(writer: writer) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 100) { writer in + testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } // the handler only writes once the channel is writable @@ -143,12 +161,14 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { testWriter.writabilityChanged(true) embedded.pipeline.fireChannelWritabilityChanged() - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .POST) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - XCTAssertEqual($0.headers.first(name: "content-length"), "100") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "100") + } + ) // the next body write will be executed once we tick the el. before we make the channel // unwritable @@ -162,9 +182,11 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { embedded.embeddedEventLoop.run() - XCTAssertNoThrow(try embedded.receiveBodyAndVerify { - XCTAssertEqual($0.readableBytes, 2) - }) + XCTAssertNoThrow( + try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + } + ) XCTAssertEqual(testWriter.written, index + 1) @@ -201,24 +223,28 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) XCTAssertTrue(embedded.isActive) @@ -247,27 +273,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -299,27 +333,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -327,7 +369,7 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) // canceling the request - requestBag.cancel() + requestBag.fail(HTTPClientError.cancelled) XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { XCTAssertEqual($0 as? HTTPClientError, .cancelled) } @@ -337,6 +379,217 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) } + func testIdleWriteTimeout() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + return testWriter.start(writer: writer) + } + ) + ) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutRaceToEnd() { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream { _ in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + let scheduled = embedded.embeddedEventLoop.flatScheduleTask(in: .milliseconds(2)) { + embedded.embeddedEventLoop.makeSucceededVoidFuture() + } + return scheduled.futureResult + } + ) + ) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(5)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + let expectedHeaders: HTTPHeaders = ["host": "localhost", "Transfer-Encoding": "chunked"] + XCTAssertEqual( + try embedded.readOutbound(as: HTTPClientRequestPart.self), + .head(HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: expectedHeaders)) + ) + + // change the writability to false. + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.run() + + // let the writer, write an end (while writability is false) + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) + } + + func testIdleWriteTimeoutWritabilityChanged() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + // This should not trigger any errors or timeouts, because the timer isn't running + // as the channel is not writable. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(20)) + + // Now that the channel will become writable, this should trigger a timeout. + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + return testWriter.start(writer: writer) + } + ) + ) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutIsCancelledIfRequestIsCancelled() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 1) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 2) { writer in + testWriter.start(writer: writer, expectedErrors: [HTTPClientError.cancelled]) + } + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + + // canceling the request + requestBag.fail(HTTPClientError.cancelled) + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + + // the idle write timeout should be cleared because we canceled the request + // therefore advancing the time should not trigger a crash + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + } + func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand() { let embedded = EmbeddedChannel() var maybeTestUtils: HTTP1TestTools? @@ -349,27 +602,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "50")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "50")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -420,7 +681,12 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } - XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandler(FailWriteHandler(), position: .after(testUtils.readEventHandler))) + XCTAssertNoThrow( + try embedded.pipeline.syncOperations.addHandler( + FailWriteHandler(), + position: .after(testUtils.readEventHandler) + ) + ) let logger = Logger(label: "test") @@ -430,16 +696,20 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) - guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to be able to create a request bag") + } embedded.isWritable = false XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) @@ -457,42 +727,229 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { XCTAssertEqual(embedded.isActive, false) } } + + func testHandlerClosesChannelIfLastActionIsSendEndAndItFails() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + testWriter.start(writer: writer) + } + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + XCTAssertNoThrow(try embedded.pipeline.addHandler(FailEndHandler(), position: .first).wait()) + + // Execute the request and we'll receive the head. + testWriter.writabilityChanged(true) + testUtils.connection.executeRequest(requestBag) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "10") + } + ) + // We're going to immediately send the response head and end. + let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) + embedded.read() + + // Send the end and confirm the connection is still live. + XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.end(nil))) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionClosed, 0) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 0) + + // Ok, now we can process some reads. We expect 5 reads, but we do _not_ expect an .end, because + // the `FailEndHandler` is going to fail it. + embedded.embeddedEventLoop.run() + XCTAssertEqual(testWriter.written, 5) + for _ in 0..<5 { + XCTAssertNoThrow( + try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + } + ) + } + + embedded.embeddedEventLoop.run() + XCTAssertNil(try embedded.readOutbound(as: HTTPClientRequestPart.self)) + + // We should have seen the connection close, and the request is complete. + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionClosed, 1) + XCTAssertEqual(testUtils.connectionDelegate.hitConnectionReleased, 0) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { error in + XCTAssertTrue(error is FailEndHandler.Error) + } + } + + func testChannelBecomesNonWritableDuringHeaderWrite() throws { + final class ChangeWritabilityOnFlush: ChannelOutboundHandler { + typealias OutboundIn = Any + func flush(context: ChannelHandlerContext) { + context.flush() + (context.channel as! EmbeddedChannel).isWritable = false + context.fireChannelWritabilityChanged() + } + } + let eventLoopGroup = EmbeddedEventLoopGroup(loops: 1) + let eventLoop = eventLoopGroup.next() as! EmbeddedEventLoop + let handler = HTTP1ClientChannelHandler( + eventLoop: eventLoop, + backgroundLogger: Logger(label: "no-op", factory: SwiftLogNoOpLogHandler.init), + connectionIdLoggerMetadata: "test connection" + ) + let channel = EmbeddedChannel( + handlers: [ + ChangeWritabilityOnFlush(), + handler, + ], + loop: eventLoop + ) + try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait() + + // non empty body is important to trigger this bug as we otherwise finish the request in a single flush + let request = MockHTTPExecutableRequest( + framingMetadata: RequestFramingMetadata(connectionClose: false, body: .fixedSize(1)), + raiseErrorIfUnimplementedMethodIsCalled: false + ) + channel.writeAndFlush(request, promise: nil) + XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent]) + } + + func testIdleWriteTimeoutOutsideOfRunningState() { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + print("pipeline", embedded.pipeline) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard var request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + // start a request stream we'll never write to + let streamPromise = embedded.eventLoop.makePromise(of: Void.self) + let streamCallback = { @Sendable (streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture in + streamPromise.futureResult + } + request.body = .init(contentLength: nil, stream: streamCallback) + + let accumulator = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests( + idleReadTimeout: .milliseconds(10), + idleWriteTimeout: .milliseconds(2) + ), + delegate: accumulator + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + testUtils.connection.executeRequest(requestBag) + + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) + + // close the pipeline to simulate a server-side close + // note this happens before we write so the idle write timeout is still running + try! embedded.pipeline.close().wait() + + // advance time to trigger the idle write timeout + // and ensure that the state machine can tolerate this + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + } } -class TestBackpressureWriter { +final class TestBackpressureWriter: Sendable { let eventLoop: EventLoop let parts: Int var finishFuture: EventLoopFuture { self.finishPromise.futureResult } private let finishPromise: EventLoopPromise - private(set) var written: Int = 0 - private var channelIsWritable: Bool = false + private struct State { + var written = 0 + var channelIsWritable = false + } + + var written: Int { + self.state.value.written + } + + private let state: NIOLoopBoundBox init(eventLoop: EventLoop, parts: Int) { self.eventLoop = eventLoop self.parts = parts - + self.state = .makeBoxSendingValue(State(), eventLoop: eventLoop) self.finishPromise = eventLoop.makePromise(of: Void.self) } - func start(writer: HTTPClient.Body.StreamWriter) -> EventLoopFuture { + func start(writer: HTTPClient.Body.StreamWriter, expectedErrors: [HTTPClientError] = []) -> EventLoopFuture { + @Sendable func recursive() { XCTAssert(self.eventLoop.inEventLoop) - XCTAssert(self.channelIsWritable) - if self.written == self.parts { + XCTAssert(self.state.value.channelIsWritable) + if self.state.value.written == self.parts { self.finishPromise.succeed(()) } else { self.eventLoop.execute { let future = writer.write(.byteBuffer(.init(bytes: [0, 1]))) - self.written += 1 + self.state.value.written += 1 future.whenComplete { result in switch result { case .success: recursive() case .failure(let error): - XCTFail("Unexpected error: \(error)") + let isExpectedError = expectedErrors.contains { httpError in + if let castError = error as? HTTPClientError { + return castError == httpError + } + return false + } + if !isExpectedError { + XCTFail("Unexpected error: \(error)") + } } } } @@ -505,14 +962,14 @@ class TestBackpressureWriter { } func writabilityChanged(_ newValue: Bool) { - self.channelIsWritable = newValue + self.state.value.channelIsWritable = newValue } } -class ResponseBackpressureDelegate: HTTPClientResponseDelegate { +final class ResponseBackpressureDelegate: HTTPClientResponseDelegate { typealias Response = Void - enum State { + enum State: Sendable { case consuming(EventLoopPromise) case waitingForRemote(CircularBuffer>) case buffering((ByteBuffer?, EventLoopPromise)?) @@ -520,40 +977,42 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } let eventLoop: EventLoop - private var state: State = .buffering(nil) + private let state: NIOLoopBoundBox init(eventLoop: EventLoop) { self.eventLoop = eventLoop - - self.state = .consuming(self.eventLoop.makePromise(of: Void.self)) + self.state = .makeBoxSendingValue(.consuming(eventLoop.makePromise(of: Void.self)), eventLoop: eventLoop) } func next() -> EventLoopFuture { - switch self.state { + switch self.state.value { case .consuming(let backpressurePromise): var promiseBuffer = CircularBuffer>() let newPromise = self.eventLoop.makePromise(of: ByteBuffer?.self) promiseBuffer.append(newPromise) - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) backpressurePromise.succeed(()) return newPromise.futureResult case .waitingForRemote(var promiseBuffer): - assert(!promiseBuffer.isEmpty, "assert expected to be waiting if we have at least one promise in the buffer") + assert( + !promiseBuffer.isEmpty, + "assert expected to be waiting if we have at least one promise in the buffer" + ) let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) promiseBuffer.append(promise) - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) return promise.futureResult case .buffering(.none): var promiseBuffer = CircularBuffer>() let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) promiseBuffer.append(promise) - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) return promise.futureResult case .buffering(.some((let buffer, let promise))): - self.state = .buffering(nil) + self.state.value = .buffering(nil) promise.succeed(()) return self.eventLoop.makeSucceededFuture(buffer) @@ -563,7 +1022,7 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - switch self.state { + switch self.state.value { case .consuming(let backpressurePromise): return backpressurePromise.futureResult @@ -576,28 +1035,33 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - switch self.state { + switch self.state.value { case .waitingForRemote(var promiseBuffer): - assert(!promiseBuffer.isEmpty, "assert expected to be waiting if we have at least one promise in the buffer") + assert( + !promiseBuffer.isEmpty, + "assert expected to be waiting if we have at least one promise in the buffer" + ) let promise = promiseBuffer.removeFirst() if promiseBuffer.isEmpty { let newBackpressurePromise = self.eventLoop.makePromise(of: Void.self) - self.state = .consuming(newBackpressurePromise) + self.state.value = .consuming(newBackpressurePromise) promise.succeed(buffer) return newBackpressurePromise.futureResult } else { - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) promise.succeed(buffer) return self.eventLoop.makeSucceededVoidFuture() } case .buffering(.none): let promise = self.eventLoop.makePromise(of: Void.self) - self.state = .buffering((buffer, promise)) + self.state.value = .buffering((buffer, promise)) return promise.futureResult case .buffering(.some): - preconditionFailure("Did receive response part should not be called, before the previous promise was succeeded.") + preconditionFailure( + "Did receive response part should not be called, before the previous promise was succeeded." + ) case .done, .consuming: preconditionFailure("Invalid state: \(self.state)") @@ -605,21 +1069,23 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } func didFinishRequest(task: HTTPClient.Task) throws { - switch self.state { + switch self.state.value { case .waitingForRemote(let promiseBuffer): - promiseBuffer.forEach { - $0.succeed(.none) + for promise in promiseBuffer { + promise.succeed(.none) } - self.state = .done + self.state.value = .done case .buffering(.none): - self.state = .done + self.state.value = .done case .done, .consuming: preconditionFailure("Invalid state: \(self.state)") case .buffering(.some): - preconditionFailure("Did receive response part should not be called, before the previous promise was succeeded.") + preconditionFailure( + "Did receive response part should not be called, before the previous promise was succeeded." + ) } } } @@ -636,3 +1102,19 @@ class ReadEventHitHandler: ChannelOutboundHandler { context.read() } } + +final class FailEndHandler: ChannelOutboundHandler, Sendable { + typealias OutboundIn = HTTPClientRequestPart + typealias OutboundOut = HTTPClientRequestPart + + struct Error: Swift.Error {} + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + if case .end = self.unwrapOutboundIn(data) { + // We fail this. + promise?.fail(Self.Error()) + } else { + context.write(data, promise: promise) + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift deleted file mode 100644 index 76a37936b..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests+XCTest.swift +++ /dev/null @@ -1,48 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP1ConnectionStateMachineTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP1ConnectionStateMachineTests { - static var allTests: [(String, (HTTP1ConnectionStateMachineTests) -> () throws -> Void)] { - return [ - ("testPOSTRequestWithWriteAndReadBackpressure", testPOSTRequestWithWriteAndReadBackpressure), - ("testResponseReadingWithBackpressure", testResponseReadingWithBackpressure), - ("testAConnectionCloseHeaderInTheRequestLeadsToConnectionCloseAfterRequest", testAConnectionCloseHeaderInTheRequestLeadsToConnectionCloseAfterRequest), - ("testAHTTP1_0ResponseWithoutKeepAliveHeaderLeadsToConnectionCloseAfterRequest", testAHTTP1_0ResponseWithoutKeepAliveHeaderLeadsToConnectionCloseAfterRequest), - ("testAHTTP1_0ResponseWithKeepAliveHeaderLeadsToConnectionBeingKeptAlive", testAHTTP1_0ResponseWithKeepAliveHeaderLeadsToConnectionBeingKeptAlive), - ("testAConnectionCloseHeaderInTheResponseLeadsToConnectionCloseAfterRequest", testAConnectionCloseHeaderInTheResponseLeadsToConnectionCloseAfterRequest), - ("testNIOTriggersChannelActiveTwice", testNIOTriggersChannelActiveTwice), - ("testIdleConnectionBecomesInactive", testIdleConnectionBecomesInactive), - ("testConnectionGoesAwayWhileInRequest", testConnectionGoesAwayWhileInRequest), - ("testRequestWasCancelledWhileUploadingData", testRequestWasCancelledWhileUploadingData), - ("testCancelRequestIsIgnoredWhenConnectionIsIdle", testCancelRequestIsIgnoredWhenConnectionIsIdle), - ("testReadsAreForwardedIfConnectionIsClosing", testReadsAreForwardedIfConnectionIsClosing), - ("testChannelReadsAreIgnoredIfConnectionIsClosing", testChannelReadsAreIgnoredIfConnectionIsClosing), - ("testRequestIsCancelledWhileWaitingForWritable", testRequestIsCancelledWhileWaitingForWritable), - ("testConnectionIsClosedIfErrorHappensWhileInRequest", testConnectionIsClosedIfErrorHappensWhileInRequest), - ("testConnectionIsClosedAfterSwitchingProtocols", testConnectionIsClosedAfterSwitchingProtocols), - ("testWeDontCrashAfterEarlyHintsAndConnectionClose", testWeDontCrashAfterEarlyHintsAndConnectionClose), - ("testWeDontCrashInRaceBetweenSchedulingNewRequestAndConnectionClose", testWeDontCrashInRaceBetweenSchedulingNewRequestAndConnectionClose), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index c8ad3d510..1c6e9659f 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOHTTP1 import NIOHTTPCompression import XCTest +@testable import AsyncHTTPClient + class HTTP1ConnectionStateMachineTests: XCTestCase { func testPOSTRequestWithWriteAndReadBackpressure() { var state = HTTP1ConnectionStateMachine() @@ -26,31 +27,38 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) let part3 = IOData.byteBuffer(ByteBuffer(bytes: [3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) // oh the channel reports... we should slow down producing... XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) // but we issued a .produceMoreRequestBodyData before... Thus, we must accept more produced // data - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) // however when we have put the data on the channel, we should not issue further // .produceMoreRequestBodyData events // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) - XCTAssertEqual(state.requestStreamPartReceived(part3), .sendBodyPart(part3)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) + XCTAssertEqual(state.requestStreamPartReceived(part3, promise: nil), .sendBodyPart(part3, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.informConnectionIsIdle, .init([responseBody]))) @@ -64,10 +72,17 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "12"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -86,16 +101,43 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) } + func testWriteTimeoutAfterErrorDoesntCrash() { + var state = HTTP1ConnectionStateMachine() + XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) + + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) + let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) + + struct MyError: Error, Equatable {} + XCTAssertEqual(state.errorHappened(MyError()), .failRequest(MyError(), .close(nil))) + + // Primarily we care that we don't crash here + XCTAssertEqual(state.idleWriteTimeoutTriggered(), .wait) + } + func testAConnectionCloseHeaderInTheRequestLeadsToConnectionCloseAfterRequest() { var state = HTTP1ConnectionStateMachine() XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: ["connection": "close"]) let metadata = RequestFramingMetadata(connectionClose: true, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) @@ -108,10 +150,17 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .ok, headers: ["content-length": "4"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) @@ -124,10 +173,21 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) - - let responseHead = HTTPResponseHead(version: .http1_0, status: .ok, headers: ["content-length": "4", "connection": "keep-alive"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_0, + status: .ok, + headers: ["content-length": "4", "connection": "keep-alive"] + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.informConnectionIsIdle, .init([responseBody]))) @@ -141,10 +201,17 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["connection": "close"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) @@ -170,9 +237,11 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) + + XCTAssertEqual(state.headSent(), .wait) } func testRequestWasCancelledWhileUploadingData() { @@ -182,13 +251,33 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - XCTAssertEqual(state.requestCancelled(closeConnection: false), .failRequest(HTTPClientError.cancelled, .close)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) + XCTAssertEqual( + state.requestCancelled(closeConnection: false), + .failRequest(HTTPClientError.cancelled, .close(nil)) + ) + } + + func testNewRequestAfterErrorHappened() { + var state = HTTP1ConnectionStateMachine() + XCTAssertEqual(state.channelActive(isWritable: false), .fireChannelActive) + struct MyError: Error, Equatable {} + XCTAssertEqual(state.errorHappened(MyError()), .fireChannelError(MyError(), closeConnection: true)) + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) + let action = state.runNewRequest(head: requestHead, metadata: metadata) + guard case .failRequest = action else { + return XCTFail("unexpected action \(action)") + } } func testCancelRequestIsIgnoredWhenConnectionIsIdle() { @@ -196,9 +285,17 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) XCTAssertEqual(state.requestCancelled(closeConnection: false), .wait, "Should be ignored.") XCTAssertEqual(state.requestCancelled(closeConnection: true), .close, "Should lead to connection closure.") - XCTAssertEqual(state.requestCancelled(closeConnection: true), .wait, "Should be ignored. Connection is already closing") + XCTAssertEqual( + state.requestCancelled(closeConnection: true), + .wait, + "Should be ignored. Connection is already closing" + ) XCTAssertEqual(state.channelInactive(), .fireChannelInactive) - XCTAssertEqual(state.requestCancelled(closeConnection: true), .wait, "Should be ignored. Connection is already closed") + XCTAssertEqual( + state.requestCancelled(closeConnection: true), + .wait, + "Should be ignored. Connection is already closed" + ) } func testReadsAreForwardedIfConnectionIsClosing() { @@ -226,7 +323,10 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.requestCancelled(closeConnection: false), .failRequest(HTTPClientError.cancelled, .informConnectionIsIdle)) + XCTAssertEqual( + state.requestCancelled(closeConnection: false), + .failRequest(HTTPClientError.cancelled, .informConnectionIsIdle) + ) } func testConnectionIsClosedIfErrorHappensWhileInRequest() { @@ -235,13 +335,20 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Hello world!\n"))), .wait) XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Foo Bar!\n"))), .wait) let decompressionError = NIOHTTPDecompression.DecompressionError.limit - XCTAssertEqual(state.errorHappened(decompressionError), .failRequest(decompressionError, .close)) + XCTAssertEqual(state.errorHappened(decompressionError), .failRequest(decompressionError, .close(nil))) } func testConnectionIsClosedAfterSwitchingProtocols() { @@ -250,9 +357,16 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [])) } @@ -262,8 +376,15 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) - XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, startBody: false)) - let responseHead = HTTPResponseHead(version: .http1_1, status: .init(statusCode: 103, reasonPhrase: "Early Hints")) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .init(statusCode: 103, reasonPhrase: "Early Hints") + ) XCTAssertEqual(state.channelRead(.head(responseHead)), .wait) XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) } @@ -291,12 +412,20 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { case (.fireChannelInactive, .fireChannelInactive): return true + case (.fireChannelError(_, let lhsCloseConnection), .fireChannelError(_, let rhsCloseConnection)): + return lhsCloseConnection == rhsCloseConnection case (.sendRequestHead(let lhsHead, let lhsStartBody), .sendRequestHead(let rhsHead, let rhsStartBody)): return lhsHead == rhsHead && lhsStartBody == rhsStartBody - case (.sendBodyPart(let lhsData), .sendBodyPart(let rhsData)): - return lhsData == rhsData + case ( + .notifyRequestHeadSendSuccessfully(let lhsResumeRequestBodyStream, let lhsStartIdleTimer), + .notifyRequestHeadSendSuccessfully(let rhsResumeRequestBodyStream, let rhsStartIdleTimer) + ): + return lhsResumeRequestBodyStream == rhsResumeRequestBodyStream && lhsStartIdleTimer == rhsStartIdleTimer + + case (.sendBodyPart(let lhsData, let lhsPromise), .sendBodyPart(let rhsData, let rhsPromise)): + return lhsData == rhsData && lhsPromise?.futureResult == rhsPromise?.futureResult case (.sendRequestEnd, .sendRequestEnd): return true @@ -306,13 +435,19 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { case (.resumeRequestBodyStream, .resumeRequestBodyStream): return true - case (.forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream)): + case ( + .forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), + .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream) + ): return lhsHead == rhsHead && lhsPauseRequestBodyStream == rhsPauseRequestBodyStream case (.forwardResponseBodyParts(let lhsData), .forwardResponseBodyParts(let rhsData)): return lhsData == rhsData - case (.succeedRequest(let lhsFinalAction, let lhsFinalBuffer), .succeedRequest(let rhsFinalAction, let rhsFinalBuffer)): + case ( + .succeedRequest(let lhsFinalAction, let lhsFinalBuffer), + .succeedRequest(let rhsFinalAction, let rhsFinalBuffer) + ): return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): @@ -332,3 +467,42 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { } } } + +extension HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction: Equatable { + public static func == ( + lhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction, + rhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction + ) -> Bool { + switch (lhs, rhs) { + case (.close, .close): + return true + case (sendRequestEnd(let lhsPromise, let lhsShouldClose), sendRequestEnd(let rhsPromise, let rhsShouldClose)): + return lhsPromise?.futureResult == rhsPromise?.futureResult && lhsShouldClose == rhsShouldClose + case (informConnectionIsIdle, informConnectionIsIdle): + return true + default: + return false + } + } +} + +extension HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction: Equatable { + public static func == ( + lhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction, + rhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction + ) -> Bool { + switch (lhs, rhs) { + case (.close(let lhsPromise), .close(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + case (.informConnectionIsIdle, .informConnectionIsIdle): + return true + case (.failWritePromise(let lhsPromise), .failWritePromise(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + case (.none, .none): + return true + + default: + return false + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift deleted file mode 100644 index 95b3e5dac..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests+XCTest.swift +++ /dev/null @@ -1,42 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP1ConnectionTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP1ConnectionTests { - static var allTests: [(String, (HTTP1ConnectionTests) -> () throws -> Void)] { - return [ - ("testCreateNewConnectionWithDecompression", testCreateNewConnectionWithDecompression), - ("testCreateNewConnectionWithoutDecompression", testCreateNewConnectionWithoutDecompression), - ("testCreateNewConnectionFailureClosedIO", testCreateNewConnectionFailureClosedIO), - ("testGETRequest", testGETRequest), - ("testConnectionClosesOnCloseHeader", testConnectionClosesOnCloseHeader), - ("testConnectionClosesOnRandomlyAppearingCloseHeader", testConnectionClosesOnRandomlyAppearingCloseHeader), - ("testConnectionClosesAfterTheRequestWithoutHavingSentAnCloseHeader", testConnectionClosesAfterTheRequestWithoutHavingSentAnCloseHeader), - ("testConnectionIsClosedAfterSwitchingProtocols", testConnectionIsClosedAfterSwitchingProtocols), - ("testConnectionDropAfterEarlyHints", testConnectionDropAfterEarlyHints), - ("testConnectionIsClosedIfResponseIsReceivedBeforeRequest", testConnectionIsClosedIfResponseIsReceivedBeforeRequest), - ("testDoubleHTTPResponseLine", testDoubleHTTPResponseLine), - ("testDownloadStreamingBackpressure", testDownloadStreamingBackpressure), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift index 3575a6080..53001b64b 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOConcurrencyHelpers import NIOCore @@ -23,6 +22,8 @@ import NIOPosix import NIOTestUtils import XCTest +@testable import AsyncHTTPClient + class HTTP1ConnectionTests: XCTestCase { func testCreateNewConnectionWithDecompression() { let embedded = EmbeddedChannel() @@ -31,19 +32,23 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) var connection: HTTP1Connection? - XCTAssertNoThrow(connection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: MockHTTP1ConnectionDelegate(), - configuration: .init(decompression: .enabled(limit: .ratio(4))), - logger: logger - )) + XCTAssertNoThrow( + connection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self)) - XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self)) + XCTAssertNotNil( + try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) + ) XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) - XCTAssertNoThrow(try connection?.close().wait()) + XCTAssertNoThrow(try connection?.sendableView.close().wait()) embedded.embeddedEventLoop.run() XCTAssert(!embedded.isActive) } @@ -54,17 +59,22 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) - XCTAssertNoThrow(try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: MockHTTP1ConnectionDelegate(), - configuration: .init(decompression: .disabled), - logger: logger - )) + XCTAssertNoThrow( + try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + decompression: .disabled, + logger: logger + ) + ) XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self)) - XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self)) - XCTAssertThrowsError(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) { error in + XCTAssertNotNil( + try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) + ) + XCTAssertThrowsError(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) { + error in XCTAssertEqual(error as? ChannelPipelineError, .notFound) } } @@ -78,13 +88,15 @@ class HTTP1ConnectionTests: XCTestCase { embedded.embeddedEventLoop.run() let logger = Logger(label: "test.http1.connection") - XCTAssertThrowsError(try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: MockHTTP1ConnectionDelegate(), - configuration: .init(), - logger: logger - )) + XCTAssertThrowsError( + try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + decompression: .disabled, + logger: logger + ) + ) } func testGETRequest() { @@ -96,8 +108,7 @@ class HTTP1ConnectionTests: XCTestCase { defer { XCTAssertNoThrow(try server.stop()) } let logger = Logger(label: "test") - let delegate = MockHTTP1ConnectionDelegate() - delegate.closePromise = clientEL.makePromise(of: Void.self) + let delegate = MockHTTP1ConnectionDelegate(closePromise: clientEL.makePromise()) let connection = try! ClientBootstrap(group: clientEL) .connect(to: .init(ipAddress: "127.0.0.1", port: server.serverPort)) @@ -106,37 +117,39 @@ class HTTP1ConnectionTests: XCTestCase { channel: $0, connectionID: 0, delegate: delegate, - configuration: .init(decompression: .disabled), + decompression: .disabled, logger: logger - ) + ).sendableView } .wait() var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "http://localhost/hello/swift", - method: .POST, - body: .stream(length: 4) { writer -> EventLoopFuture in - func recursive(count: UInt8, promise: EventLoopPromise) { - guard count < 4 else { - return promise.succeed(()) - } + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/hello/swift", + method: .POST, + body: .stream(contentLength: 4) { writer -> EventLoopFuture in + @Sendable func recursive(count: UInt8, promise: EventLoopPromise) { + guard count < 4 else { + return promise.succeed(()) + } - writer.write(.byteBuffer(ByteBuffer(bytes: [count]))).whenComplete { result in - switch result { - case .failure(let error): - XCTFail("Unexpected error: \(error)") - case .success: - recursive(count: count + 1, promise: promise) + writer.write(.byteBuffer(ByteBuffer(bytes: [count]))).whenComplete { result in + switch result { + case .failure(let error): + XCTFail("Unexpected error: \(error)") + case .success: + recursive(count: count + 1, promise: promise) + } } } - } - let promise = clientEL.makePromise(of: Void.self) - recursive(count: 0, promise: promise) - return promise.futureResult - } - )) + let promise = clientEL.makePromise(of: Void.self) + recursive(count: 0, promise: promise) + return promise.futureResult + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a connection and a request") @@ -145,33 +158,39 @@ class HTTP1ConnectionTests: XCTestCase { let task = HTTPClient.Task(eventLoop: clientEL, logger: logger) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: clientEL), - task: task, - redirectHandler: nil, - connectionDeadline: .now() + .seconds(60), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: request) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: clientEL), + task: task, + redirectHandler: nil, + connectionDeadline: .now() + .seconds(60), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: request) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } connection.executeRequest(requestBag) - XCTAssertNoThrow(try server.receiveHeadAndVerify { head in - XCTAssertEqual(head.method, .POST) - XCTAssertEqual(head.uri, "/hello/swift") - XCTAssertEqual(head.headers["content-length"].first, "4") - }) + XCTAssertNoThrow( + try server.receiveHeadAndVerify { head in + XCTAssertEqual(head.method, .POST) + XCTAssertEqual(head.uri, "/hello/swift") + XCTAssertEqual(head.headers["content-length"].first, "4") + } + ) var received: UInt8 = 0 while received < 4 { - XCTAssertNoThrow(try server.receiveBodyAndVerify { body in - var body = body - while let read = body.readInteger(as: UInt8.self) { - XCTAssertEqual(received, read) - received += 1 + XCTAssertNoThrow( + try server.receiveBodyAndVerify { body in + var body = body + while let read = body.readInteger(as: UInt8.self) { + XCTAssertEqual(received, read) + received += 1 + } } - }) + ) } XCTAssertEqual(received, 4) XCTAssertNoThrow(try server.receiveEnd()) @@ -198,17 +217,23 @@ class HTTP1ConnectionTests: XCTestCase { var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait() + ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") - var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try eventLoop.submit { try HTTP1Connection.start( - channel: XCTUnwrap(maybeChannel), - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(), - logger: logger - ) }.wait()) + var maybeConnection: HTTP1Connection.SendableView? + XCTAssertNoThrow( + maybeConnection = try eventLoop.submit { [maybeChannel] in + try HTTP1Connection.start( + channel: XCTUnwrap(maybeChannel), + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ).sendableView + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } var maybeRequest: HTTPClient.Request? @@ -217,15 +242,17 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: eventLoopGroup.next()), - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: eventLoopGroup.next()), + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } connection.executeRequest(requestBag) @@ -248,21 +275,29 @@ class HTTP1ConnectionTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let closeOnRequest = (30...100).randomElement()! - let httpBin = HTTPBin(handlerFactory: { _ in SuddenlySendsCloseHeaderChannelHandler(closeOnRequest: closeOnRequest) }) + let httpBin = HTTPBin(handlerFactory: { _ in + SuddenlySendsCloseHeaderChannelHandler(closeOnRequest: closeOnRequest) + }) var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait() + ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") - var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try eventLoop.submit { try HTTP1Connection.start( - channel: XCTUnwrap(maybeChannel), - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(), - logger: logger - ) }.wait()) + var maybeConnection: HTTP1Connection.SendableView? + XCTAssertNoThrow( + maybeConnection = try eventLoop.submit { [maybeChannel] in + try HTTP1Connection.start( + channel: XCTUnwrap(maybeChannel), + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ).sendableView + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } var counter = 0 @@ -275,16 +310,20 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: eventLoopGroup.next()), - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) - guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: eventLoopGroup.next()), + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to be able to create a request bag") + } connection.executeRequest(requestBag) @@ -293,7 +332,7 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertEqual(response?.status, .ok) if response?.headers.first(name: "connection") == "close" { - break // the loop + break // the loop } else { XCTAssertEqual(httpBin.activeConnections, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, counter) @@ -306,8 +345,11 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertEqual(counter, closeOnRequest) XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) - XCTAssertEqual(connectionDelegate.hitConnectionReleased, counter - 1, - "If a close header is received connection release is not triggered.") + XCTAssertEqual( + connectionDelegate.hitConnectionReleased, + counter - 1, + "If a close header is received connection release is not triggered." + ) // we need to wait a small amount of time to see the connection close on the server try! eventLoop.scheduleTask(in: .milliseconds(200)) {}.futureResult.wait() @@ -324,17 +366,23 @@ class HTTP1ConnectionTests: XCTestCase { var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait() + ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") - var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try eventLoop.submit { try HTTP1Connection.start( - channel: XCTUnwrap(maybeChannel), - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(), - logger: logger - ) }.wait()) + var maybeConnection: HTTP1Connection.SendableView? + XCTAssertNoThrow( + maybeConnection = try eventLoop.submit { [maybeChannel] in + try HTTP1Connection.start( + channel: XCTUnwrap(maybeChannel), + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ).sendableView + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } var maybeRequest: HTTPClient.Request? @@ -343,15 +391,17 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: eventLoopGroup.next()), - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: eventLoopGroup.next()), + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } connection.executeRequest(requestBag) @@ -373,13 +423,15 @@ class HTTP1ConnectionTests: XCTestCase { var maybeConnection: HTTP1Connection? let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(decompression: .enabled(limit: .ratio(4))), - logger: logger - )) + XCTAssertNoThrow( + maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } var maybeRequest: HTTPClient.Request? @@ -388,38 +440,40 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } - connection.executeRequest(requestBag) + connection.sendableView.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end let responseString = """ - HTTP/1.1 101 Switching Protocols\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: xAMUK7/Il9bLRFJrikq6mm8CNZI=\r\n\ - Connection: upgrade\r\n\ - date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ - \r\n\ - \r\nfoo bar baz - """ + HTTP/1.1 101 Switching Protocols\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: xAMUK7/Il9bLRFJrikq6mm8CNZI=\r\n\ + Connection: upgrade\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\nfoo bar baz + """ XCTAssertTrue(embedded.isActive) XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) XCTAssertNoThrow(try embedded.writeInbound(ByteBuffer(string: responseString))) XCTAssertFalse(embedded.isActive) - (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) @@ -438,13 +492,15 @@ class HTTP1ConnectionTests: XCTestCase { var maybeConnection: HTTP1Connection? let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(decompression: .enabled(limit: .ratio(4))), - logger: logger - )) + XCTAssertNoThrow( + maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } var maybeRequest: HTTPClient.Request? @@ -453,28 +509,30 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } - connection.executeRequest(requestBag) + connection.sendableView.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end let responseString = """ - HTTP/1.1 103 Early Hints\r\n\ - date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ - \r\n\ - \r\n - """ + HTTP/1.1 103 Early Hints\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\n + """ XCTAssertTrue(embedded.isActive) XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) @@ -484,7 +542,7 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertTrue(embedded.isActive, "The connection remains active after the informational response head") XCTAssertNoThrow(try embedded.close().wait(), "the connection was closed") - embedded.embeddedEventLoop.run() // tick once to run futures. + embedded.embeddedEventLoop.run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) @@ -500,20 +558,22 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait()) let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(decompression: .enabled(limit: .ratio(4))), - logger: logger - )) + XCTAssertNoThrow( + try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) let responseString = """ - HTTP/1.1 200 OK\r\n\ - date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ - \r\n\ - \r\n - """ + HTTP/1.1 200 OK\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\n + """ XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) @@ -522,7 +582,7 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertEqual($0 as? NIOHTTPDecoderError, .unsolicitedResponse) } XCTAssertFalse(embedded.isActive) - (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) } @@ -535,13 +595,15 @@ class HTTP1ConnectionTests: XCTestCase { var maybeConnection: HTTP1Connection? let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(decompression: .enabled(limit: .ratio(4))), - logger: logger - )) + XCTAssertNoThrow( + maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } var maybeRequest: HTTPClient.Request? @@ -550,32 +612,34 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } - connection.executeRequest(requestBag) + connection.sendableView.executeRequest(requestBag) let responseString = """ - HTTP/1.0 200 OK\r\n\ - HTTP/1.0 200 OK\r\n\r\n - """ + HTTP/1.0 200 OK\r\n\ + HTTP/1.0 200 OK\r\n\r\n + """ - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) XCTAssertNoThrow(try embedded.writeInbound(ByteBuffer(string: responseString))) XCTAssertFalse(embedded.isActive) - (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) } @@ -589,42 +653,42 @@ class HTTP1ConnectionTests: XCTestCase { // bytes a ready to be read as well. This will allow us to test if subsequent reads // are waiting for backpressure promise. func testDownloadStreamingBackpressure() { - class BackpressureTestDelegate: HTTPClientResponseDelegate { + final class BackpressureTestDelegate: HTTPClientResponseDelegate { typealias Response = Void - var _reads = 0 - var _channel: Channel? + private struct State: Sendable { + var reads = 0 + var channel: Channel? + } + + private let state = NIOLockedValueBox(State()) + + var reads: Int { + self.state.withLockedValue { $0.reads } + } - let lock: Lock let backpressurePromise: EventLoopPromise let messageReceived: EventLoopPromise init(eventLoop: EventLoop) { - self.lock = Lock() self.backpressurePromise = eventLoop.makePromise() self.messageReceived = eventLoop.makePromise() } - var reads: Int { - return self.lock.withLock { - self._reads - } - } - func willExecuteOnChannel(_ channel: Channel) { - self.lock.withLockVoid { - self._channel = channel + self.state.withLockedValue { + $0.channel = channel } } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - return task.futureResult.eventLoop.makeSucceededVoidFuture() + task.futureResult.eventLoop.makeSucceededVoidFuture() } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { // We count a number of reads received. - self.lock.withLockVoid { - self._reads += 1 + self.state.withLockedValue { + $0.reads += 1 } // We need to notify the test when first byte of the message is arrived. self.messageReceived.succeed(()) @@ -656,8 +720,8 @@ class HTTP1ConnectionTests: XCTestCase { let buffer = context.channel.allocator.buffer(string: "1234") context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) - self.endFuture.hop(to: context.eventLoop).whenSuccess { - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + self.endFuture.hop(to: context.eventLoop).assumeIsolated().whenSuccess { + context.writeAndFlush(Self.wrapOutboundOut(.end(nil)), promise: nil) } } } @@ -679,34 +743,42 @@ class HTTP1ConnectionTests: XCTestCase { defer { XCTAssertNoThrow(try httpBin.shutdown()) } var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoopGroup) - .channelOption(ChannelOptions.maxMessagesPerRead, value: 1) - .channelOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1)) - .connect(host: "localhost", port: httpBin.port) - .wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoopGroup) + .channelOption(ChannelOptions.maxMessagesPerRead, value: 1) + .channelOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1)) + .connect(host: "localhost", port: httpBin.port) + .wait() + ) guard let channel = maybeChannel else { return XCTFail("Expected to have a channel at this point") } let connectionDelegate = MockConnectionDelegate() - var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try channel.eventLoop.submit { try HTTP1Connection.start( - channel: channel, - connectionID: 0, - delegate: connectionDelegate, - configuration: .init(), - logger: logger - ) }.wait()) + var maybeConnection: HTTP1Connection.SendableView? + XCTAssertNoThrow( + maybeConnection = try channel.eventLoop.submit { + try HTTP1Connection.start( + channel: channel, + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ).sendableView + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point") } var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: HTTPClient.Request(url: "http://localhost:\(httpBin.port)/custom"), - eventLoopPreference: .delegate(on: requestEventLoop), - task: .init(eventLoop: requestEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: backpressureDelegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: HTTPClient.Request(url: "http://localhost:\(httpBin.port)/custom"), + eventLoopPreference: .delegate(on: requestEventLoop), + task: .init(eventLoop: requestEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: backpressureDelegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } backpressureDelegate.willExecuteOnChannel(connection.channel) @@ -729,15 +801,20 @@ class HTTP1ConnectionTests: XCTestCase { } } -class MockHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { - var releasePromise: EventLoopPromise? - var closePromise: EventLoopPromise? +final class MockHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { + let releasePromise: EventLoopPromise? + let closePromise: EventLoopPromise? + + init(releasePromise: EventLoopPromise? = nil, closePromise: EventLoopPromise? = nil) { + self.releasePromise = releasePromise + self.closePromise = closePromise + } - func http1ConnectionReleased(_: HTTP1Connection) { + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) { self.releasePromise?.succeed(()) } - func http1ConnectionClosed(_: HTTP1Connection) { + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { self.closePromise?.succeed(()) } } @@ -764,7 +841,12 @@ class SuddenlySendsCloseHeaderChannelHandler: ChannelInboundHandler { break case .end: if self.closeOnRequest == self.counter { - context.write(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: ["connection": "close"]))), promise: nil) + context.write( + self.wrapOutboundOut( + .head(.init(version: .http1_1, status: .ok, headers: ["connection": "close"])) + ), + promise: nil + ) context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() self.counter += 1 @@ -797,38 +879,40 @@ class AfterRequestCloseConnectionChannelHandler: ChannelInboundHandler { context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() - context.eventLoop.scheduleTask(in: .milliseconds(20)) { + context.eventLoop.assumeIsolated().scheduleTask(in: .milliseconds(20)) { context.close(promise: nil) } } } } -class MockConnectionDelegate: HTTP1ConnectionDelegate { - private var lock = Lock() +final class MockConnectionDelegate: HTTP1ConnectionDelegate { + private let counts = NIOLockedValueBox(Counts()) - private var _hitConnectionReleased = 0 - private var _hitConnectionClosed = 0 + private struct Counts: Sendable { + var hitConnectionReleased = 0 + var hitConnectionClosed = 0 + } var hitConnectionReleased: Int { - self.lock.withLock { self._hitConnectionReleased } + self.counts.withLockedValue { $0.hitConnectionReleased } } var hitConnectionClosed: Int { - self.lock.withLock { self._hitConnectionClosed } + self.counts.withLockedValue { $0.hitConnectionClosed } } init() {} - func http1ConnectionReleased(_: HTTP1Connection) { - self.lock.withLockVoid { - self._hitConnectionReleased += 1 + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) { + self.counts.withLockedValue { + $0.hitConnectionReleased += 1 } } - func http1ConnectionClosed(_: HTTP1Connection) { - self.lock.withLockVoid { - self._hitConnectionClosed += 1 + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { + self.counts.withLockedValue { + $0.hitConnectionClosed += 1 } } } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests+XCTest.swift deleted file mode 100644 index 15c432037..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests+XCTest.swift +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP1ProxyConnectHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP1ProxyConnectHandlerTests { - static var allTests: [(String, (HTTP1ProxyConnectHandlerTests) -> () throws -> Void)] { - return [ - ("testProxyConnectWithoutAuthorizationSuccess", testProxyConnectWithoutAuthorizationSuccess), - ("testProxyConnectWithAuthorization", testProxyConnectWithAuthorization), - ("testProxyConnectWithoutAuthorizationFailure500", testProxyConnectWithoutAuthorizationFailure500), - ("testProxyConnectWithoutAuthorizationButAuthorizationNeeded", testProxyConnectWithoutAuthorizationButAuthorizationNeeded), - ("testProxyConnectReceivesBody", testProxyConnectReceivesBody), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift index bbe6fab1f..d75865da2 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class HTTP1ProxyConnectHandlerTests: XCTestCase { func testProxyConnectWithoutAuthorizationSuccess() { let embedded = EmbeddedChannel() @@ -43,6 +44,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { XCTAssertEqual(head.method, .CONNECT) XCTAssertEqual(head.uri, "swift.org:443") + XCTAssertEqual(head.headers["host"].first, "swift.org") XCTAssertNil(head.headers["proxy-authorization"].first) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) @@ -76,6 +78,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { XCTAssertEqual(head.method, .CONNECT) XCTAssertEqual(head.uri, "swift.org:443") + XCTAssertEqual(head.headers["host"].first, "swift.org") XCTAssertEqual(head.headers["proxy-authorization"].first, "Basic abc123") XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) @@ -109,6 +112,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { XCTAssertEqual(head.method, .CONNECT) XCTAssertEqual(head.uri, "swift.org:443") + XCTAssertEqual(head.headers["host"].first, "swift.org") XCTAssertNil(head.headers["proxy-authorization"].first) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) @@ -148,6 +152,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { XCTAssertEqual(head.method, .CONNECT) XCTAssertEqual(head.uri, "swift.org:443") + XCTAssertEqual(head.headers["host"].first, "swift.org") XCTAssertNil(head.headers["proxy-authorization"].first) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) @@ -187,6 +192,7 @@ class HTTP1ProxyConnectHandlerTests: XCTestCase { XCTAssertEqual(head.method, .CONNECT) XCTAssertEqual(head.uri, "swift.org:443") + XCTAssertEqual(head.headers["host"].first, "swift.org") XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift deleted file mode 100644 index 8fa219838..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests+XCTest.swift +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP2ClientRequestHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP2ClientRequestHandlerTests { - static var allTests: [(String, (HTTP2ClientRequestHandlerTests) -> () throws -> Void)] { - return [ - ("testResponseBackpressure", testResponseBackpressure), - ("testWriteBackpressure", testWriteBackpressure), - ("testIdleReadTimeout", testIdleReadTimeout), - ("testIdleReadTimeoutIsCanceledIfRequestIsCanceled", testIdleReadTimeoutIsCanceledIfRequestIsCanceled), - ("testWriteHTTPHeadFails", testWriteHTTPHeadFails), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift index e67529ad8..71f7f3d1a 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class HTTP2ClientRequestHandlerTests: XCTestCase { func testResponseBackpressure() { let embedded = EmbeddedChannel() @@ -34,28 +35,36 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.write(requestBag, promise: nil) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(readEventHandler.readHitCounter, 0) embedded.read() @@ -115,22 +124,30 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 50) var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 100) { writer in - testWriter.start(writer: writer) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 100) { writer in + testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.isWritable = false @@ -143,12 +160,14 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { testWriter.writabilityChanged(true) embedded.pipeline.fireChannelWritabilityChanged() - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .POST) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - XCTAssertEqual($0.headers.first(name: "content-length"), "100") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "100") + } + ) // the next body write will be executed once we tick the el. before we make the channel // unwritable @@ -162,9 +181,11 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { embedded.embeddedEventLoop.run() - XCTAssertNoThrow(try embedded.receiveBodyAndVerify { - XCTAssertEqual($0.readableBytes, 2) - }) + XCTAssertNoThrow( + try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + } + ) XCTAssertEqual(testWriter.written, index + 1) @@ -198,27 +219,35 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.write(requestBag, promise: nil) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(readEventHandler.readHitCounter, 0) embedded.read() @@ -248,27 +277,35 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.write(requestBag, promise: nil) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(readEventHandler.readHitCounter, 0) embedded.read() @@ -276,7 +313,164 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) // canceling the request - requestBag.cancel() + requestBag.fail(HTTPClientError.cancelled) + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .cancelled) + } + + // the idle read timeout should be cleared because we canceled the request + // therefore advancing the time should not trigger a crash + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + } + + func testIdleWriteTimeout() { + let embedded = EmbeddedChannel() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([requestHandler])) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + let logger = Logger(label: "test") + + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + return testWriter.start(writer: writer) + } + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + embedded.write(requestBag, promise: nil) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutWritabilityChanged() { + let embedded = EmbeddedChannel() + let readEventHandler = ReadEventHitHandler() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([readEventHandler, requestHandler])) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + let logger = Logger(label: "test") + + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + // This should not trigger any errors or timeouts, because the timer isn't running + // as the channel is not writable. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(20)) + + // Now that the channel will become writable, this should trigger a timeout. + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + return testWriter.start(writer: writer) + } + ) + ) + + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + embedded.write(requestBag, promise: nil) + + XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testIdleWriteTimeoutIsCanceledIfRequestIsCanceled() { + let embedded = EmbeddedChannel() + let readEventHandler = ReadEventHitHandler() + let requestHandler = HTTP2ClientRequestHandler(eventLoop: embedded.eventLoop) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandlers([readEventHandler, requestHandler])) + XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) + let logger = Logger(label: "test") + + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 2) { writer in + testWriter.start(writer: writer, expectedErrors: [HTTPClientError.cancelled]) + } + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + embedded.isWritable = true + testWriter.writabilityChanged(true) + embedded.pipeline.fireChannelWritabilityChanged() + embedded.write(requestBag, promise: nil) + + // canceling the request + requestBag.fail(HTTPClientError.cancelled) XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { XCTAssertEqual($0 as? HTTPClientError, .cancelled) } @@ -318,16 +512,20 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) - guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to be able to create a request bag") + } embedded.isWritable = false XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) @@ -335,6 +533,7 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { // the handler only writes once the channel is writable XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .none) + XCTAssertTrue(embedded.isActive) embedded.isWritable = true embedded.pipeline.fireChannelWritabilityChanged() @@ -342,7 +541,39 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { XCTAssertEqual($0 as? WriteError, WriteError()) } - XCTAssertEqual(embedded.isActive, false) + XCTAssertFalse(embedded.isActive) + } + } + + func testChannelBecomesNonWritableDuringHeaderWrite() throws { + final class ChangeWritabilityOnFlush: ChannelOutboundHandler { + typealias OutboundIn = Any + func flush(context: ChannelHandlerContext) { + context.flush() + (context.channel as! EmbeddedChannel).isWritable = false + context.fireChannelWritabilityChanged() + } } + let eventLoopGroup = EmbeddedEventLoopGroup(loops: 1) + let eventLoop = eventLoopGroup.next() as! EmbeddedEventLoop + let handler = HTTP2ClientRequestHandler( + eventLoop: eventLoop + ) + let channel = EmbeddedChannel( + handlers: [ + ChangeWritabilityOnFlush(), + handler, + ], + loop: eventLoop + ) + try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait() + + // non empty body is important to trigger this bug as we otherwise finish the request in a single flush + let request = MockHTTPExecutableRequest( + framingMetadata: RequestFramingMetadata(connectionClose: false, body: .fixedSize(1)), + raiseErrorIfUnimplementedMethodIsCalled: false + ) + channel.writeAndFlush(request, promise: nil) + XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent]) } } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift deleted file mode 100644 index e7f399658..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift +++ /dev/null @@ -1,42 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP2ClientTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP2ClientTests { - static var allTests: [(String, (HTTP2ClientTests) -> () throws -> Void)] { - return [ - ("testSimpleGet", testSimpleGet), - ("testStreamRequestBodyWithoutKnowledgeAboutLength", testStreamRequestBodyWithoutKnowledgeAboutLength), - ("testStreamRequestBodyWithFalseKnowledgeAboutLength", testStreamRequestBodyWithFalseKnowledgeAboutLength), - ("testConcurrentRequests", testConcurrentRequests), - ("testConcurrentRequestsFromDifferentThreads", testConcurrentRequestsFromDifferentThreads), - ("testConcurrentRequestsWorkWithRequiredEventLoop", testConcurrentRequestsWorkWithRequiredEventLoop), - ("testUncleanShutdownCancelsExecutingAndQueuedTasks", testUncleanShutdownCancelsExecutingAndQueuedTasks), - ("testCancelingRunningRequest", testCancelingRunningRequest), - ("testReadTimeout", testReadTimeout), - ("testH2CanHandleRequestsThatHaveAlreadyHitTheDeadline", testH2CanHandleRequestsThatHaveAlreadyHitTheDeadline), - ("testStressCancelingRunningRequestFromDifferentThreads", testStressCancelingRunningRequestFromDifferentThreads), - ("testPlatformConnectErrorIsForwardedOnTimeout", testPlatformConnectErrorIsForwardedOnTimeout), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift index eb1ac2ddc..183a227bd 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift @@ -12,20 +12,24 @@ // //===----------------------------------------------------------------------===// -/* NOT @testable */ import AsyncHTTPClient // Tests that really need @testable go into HTTP2ClientInternalTests.swift -#if canImport(Network) -import Network -#endif +import AsyncHTTPClient // NOT @testable - tests that really need @testable go into HTTP2ClientInternalTests.swift import Logging +import NIOConcurrencyHelpers import NIOCore +import NIOFoundationCompat import NIOHTTP1 +import NIOHTTP2 import NIOPosix import NIOSSL import XCTest +#if canImport(Network) +import Network +#endif + class HTTP2ClientTests: XCTestCase { func makeDefaultHTTPClient( - eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .createNew + eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .singleton ) -> HTTPClient { var config = HTTPClient.Configuration() config.tlsConfiguration = .clientDefault @@ -40,7 +44,7 @@ class HTTP2ClientTests: XCTestCase { func makeClientWithActiveHTTP2Connection( to bin: HTTPBin, - eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .createNew + eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .singleton ) -> HTTPClient { let client = self.makeDefaultHTTPClient(eventLoopGroupProvider: eventLoopGroupProvider) var response: HTTPClient.Response? @@ -68,7 +72,7 @@ class HTTP2ClientTests: XCTestCase { let client = self.makeDefaultHTTPClient() defer { XCTAssertNoThrow(try client.syncShutdown()) } var response: HTTPClient.Response? - let body = HTTPClient.Body.stream(length: nil) { writer in + let body = HTTPClient.Body.stream(contentLength: nil) { writer in writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))).flatMap { writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))) } @@ -84,7 +88,7 @@ class HTTP2ClientTests: XCTestCase { defer { XCTAssertNoThrow(try bin.shutdown()) } let client = self.makeDefaultHTTPClient() defer { XCTAssertNoThrow(try client.syncShutdown()) } - let body = HTTPClient.Body.stream(length: 12) { writer in + let body = HTTPClient.Body.stream(contentLength: 12) { writer in writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))).flatMap { writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))) } @@ -132,8 +136,8 @@ class HTTP2ClientTests: XCTestCase { let q = DispatchQueue(label: "worker \(w)") q.async(group: allDone) { func go() { - allWorkersReady.signal() // tell the driver we're ready - allWorkersGo.wait() // wait for the driver to let us go + allWorkersReady.signal() // tell the driver we're ready + allWorkersGo.wait() // wait for the driver to let us go for _ in 0..] = [] - XCTAssertNoThrow(results = try EventLoopFuture - .whenAllComplete(responses, on: clientGroup.next()) - .timeout(after: .seconds(2)) - .wait()) + XCTAssertNoThrow( + results = + try EventLoopFuture + .whenAllComplete(responses, on: clientGroup.next()) + .timeout(after: .seconds(2)) + .wait() + ) for result in results { switch result { @@ -276,18 +284,19 @@ class HTTP2ClientTests: XCTestCase { XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(bin.port)")) guard let request = maybeRequest else { return } - var task: HTTPClient.Task! + let taskBox = NIOLockedValueBox?>(nil) let delegate = HeadReceivedCallback { _ in // request is definitely running because we just received a head from the server - task.cancel() + taskBox.withLockedValue { $0 }!.cancel() } - task = client.execute( + let task = client.execute( request: request, delegate: delegate ) + taskBox.withLockedValue { $0 = task } XCTAssertThrowsError(try task.futureResult.timeout(after: .seconds(2)).wait()) { - XCTAssertEqual($0 as? HTTPClientError, .cancelled) + XCTAssertEqualTypeAndValue($0, HTTPClientError.cancelled) } } @@ -301,7 +310,7 @@ class HTTP2ClientTests: XCTestCase { config.httpVersion = .automatic config.timeout.read = .milliseconds(100) let client = HTTPClient( - eventLoopGroupProvider: .createNew, + eventLoopGroupProvider: .singleton, configuration: config, backgroundActivityLogger: Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) ) @@ -322,7 +331,8 @@ class HTTP2ClientTests: XCTestCase { config.tlsConfiguration = tlsConfig config.httpVersion = .automatic let client = HTTPClient( - eventLoopGroupProvider: .createNew, + // TODO: Test fails if the provided ELG is a multi-threaded NIOTSEventLoopGroup (probably racy) + eventLoopGroupProvider: .shared(bin.group), configuration: config, backgroundActivityLogger: Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) ) @@ -352,18 +362,20 @@ class HTTP2ClientTests: XCTestCase { guard let request = maybeRequest else { return } let tasks = (0..<100).map { _ -> HTTPClient.Task in - var task: HTTPClient.Task! + let taskBox = NIOLockedValueBox?>(nil) + let delegate = HeadReceivedCallback { _ in // request is definitely running because we just received a head from the server cancelPool.next().execute { // canceling from a different thread - task.cancel() + taskBox.withLockedValue { $0 }!.cancel() } } - task = client.execute( + let task = client.execute( request: request, delegate: delegate ) + taskBox.withLockedValue { $0 = task } return task } @@ -375,7 +387,7 @@ class HTTP2ClientTests: XCTestCase { } func testPlatformConnectErrorIsForwardedOnTimeout() { - let bin = HTTPBin(.http2(compress: false)) + let bin = HTTPBin(.http2(compress: false), reusePort: true) let clientGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) let el1 = clientGroup.next() let el2 = clientGroup.next() @@ -396,27 +408,35 @@ class HTTP2ClientTests: XCTestCase { XCTAssertNoThrow(maybeRequest1 = try HTTPClient.Request(url: "https://localhost:\(bin.port)/get")) guard let request1 = maybeRequest1 else { return } - let task1 = client.execute(request: request1, delegate: ResponseAccumulator(request: request1), eventLoop: .delegateAndChannel(on: el1)) + let task1 = client.execute( + request: request1, + delegate: ResponseAccumulator(request: request1), + eventLoop: .delegateAndChannel(on: el1) + ) var response1: ResponseAccumulator.Response? XCTAssertNoThrow(response1 = try task1.wait()) XCTAssertEqual(.ok, response1?.status) XCTAssertEqual(response1?.version, .http2) let serverPort = bin.port - XCTAssertNoThrow(try bin.shutdown()) - // client is now in HTTP/2 state and the HTTPBin is closed - // start a new server on the old port which closes all connections immediately + let serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try serverGroup.syncShutdownGracefully()) } var maybeServer: Channel? - XCTAssertNoThrow(maybeServer = try ServerBootstrap(group: serverGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelInitializer { channel in - channel.close() - } - .childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .bind(host: "127.0.0.1", port: serverPort) - .wait()) + XCTAssertNoThrow( + maybeServer = try ServerBootstrap(group: serverGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: 1) + .childChannelInitializer { channel in + channel.close() + } + .childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .bind(host: "127.0.0.1", port: serverPort) + .wait() + ) + // shutting down the old server closes all connections immediately + XCTAssertNoThrow(try bin.shutdown()) + // client is now in HTTP/2 state and the HTTPBin is closed guard let server = maybeServer else { return } defer { XCTAssertNoThrow(try server.close().wait()) } @@ -424,7 +444,11 @@ class HTTP2ClientTests: XCTestCase { XCTAssertNoThrow(maybeRequest2 = try HTTPClient.Request(url: "https://localhost:\(serverPort)/")) guard let request2 = maybeRequest2 else { return } - let task2 = client.execute(request: request2, delegate: ResponseAccumulator(request: request2), eventLoop: .delegateAndChannel(on: el2)) + let task2 = client.execute( + request: request2, + delegate: ResponseAccumulator(request: request2), + eventLoop: .delegateAndChannel(on: el2) + ) XCTAssertThrowsError(try task2.wait()) { error in XCTAssertNil( error as? HTTPClientError, @@ -432,12 +456,103 @@ class HTTP2ClientTests: XCTestCase { ) } } + + func testMassiveDownload() { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = self.makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + var response: HTTPClient.Response? + XCTAssertNoThrow(response = try client.get(url: "https://localhost:\(bin.port)/mega-chunked").wait()) + + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http2) + XCTAssertEqual(response?.body?.readableBytes, 10_000) + } + + func testSimplePost() { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = self.makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + var response: HTTPClient.Response? + XCTAssertNoThrow( + response = try client.post( + url: "https://localhost:\(bin.port)/post", + body: .byteBuffer(ByteBuffer(repeating: 0, count: 12345)) + ).wait() + ) + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http2) + XCTAssertEqual( + String(buffer: ByteBuffer(repeating: 0, count: 12345)), + try response?.body.map { body in + try JSONDecoder().decode(RequestInfo.self, from: body) + }?.data + ) + } + + func testHugePost() { + // Regression test for https://github.com/swift-server/async-http-client/issues/784 + let group = MultiThreadedEventLoopGroup(numberOfThreads: 2) // This needs to be more than 1! + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + var serverH2Settings: HTTP2Settings = HTTP2Settings() + serverH2Settings.append(HTTP2Setting(parameter: .maxFrameSize, value: 16 * 1024 * 1024 - 1)) + serverH2Settings.append(HTTP2Setting(parameter: .initialWindowSize, value: Int(Int32.max))) + let bin = HTTPBin( + .http2(compress: false, settings: serverH2Settings) + ) + defer { XCTAssertNoThrow(try bin.shutdown()) } + var clientConfig = HTTPClient.Configuration() + clientConfig.tlsConfiguration = .clientDefault + clientConfig.tlsConfiguration?.certificateVerification = .none + clientConfig.httpVersion = .automatic + let client = HTTPClient( + eventLoopGroupProvider: .shared(group), + configuration: clientConfig, + backgroundActivityLogger: Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + ) + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + let loop1 = group.next() + let loop2 = group.next() + precondition(loop1 !== loop2, "bug in test setup, need two distinct loops") + + XCTAssertNoThrow( + try client.execute( + request: .init(url: "https://localhost:\(bin.port)/get"), + eventLoop: .delegateAndChannel(on: loop1) // This will force the channel to live on `loop1`. + ).wait() + ) + var response: HTTPClient.Response? + let byteCount = 1024 * 1024 * 1024 // 1 GiB (unfortunately it has to be that big to trigger the bug) + XCTAssertNoThrow( + response = try client.execute( + request: HTTPClient.Request( + url: "https://localhost:\(bin.port)/post-respond-with-byte-count", + method: .POST, + body: .data(Data(repeating: 0, count: byteCount)) + ), + eventLoop: .delegate(on: loop2) + ).wait() + ) + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http2) + XCTAssertEqual( + "\(byteCount)", + try response?.body.map { body in + try JSONDecoder().decode(RequestInfo.self, from: body) + }?.data + ) + } } private final class HeadReceivedCallback: HTTPClientResponseDelegate { typealias Response = Void - private let didReceiveHeadCallback: (HTTPResponseHead) -> Void - init(didReceiveHead: @escaping (HTTPResponseHead) -> Void) { + private let didReceiveHeadCallback: @Sendable (HTTPResponseHead) -> Void + init(didReceiveHead: @escaping @Sendable (HTTPResponseHead) -> Void) { self.didReceiveHeadCallback = didReceiveHead } @@ -458,11 +573,17 @@ private final class SendHeaderAndWaitChannelHandler: ChannelInboundHandler { let requestPart = self.unwrapInboundIn(data) switch requestPart { case .head: - context.writeAndFlush(self.wrapOutboundOut(.head(HTTPResponseHead( - version: HTTPVersion(major: 1, minor: 1), - status: .ok - )) - ), promise: nil) + context.writeAndFlush( + self.wrapOutboundOut( + .head( + HTTPResponseHead( + version: HTTPVersion(major: 1, minor: 1), + status: .ok + ) + ) + ), + promise: nil + ) case .body, .end: return } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests+XCTest.swift deleted file mode 100644 index 9f9582d9f..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests+XCTest.swift +++ /dev/null @@ -1,34 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP2ConnectionTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP2ConnectionTests { - static var allTests: [(String, (HTTP2ConnectionTests) -> () throws -> Void)] { - return [ - ("testCreateNewConnectionFailureClosedIO", testCreateNewConnectionFailureClosedIO), - ("testSimpleGetRequest", testSimpleGetRequest), - ("testEveryDoneRequestLeadsToAStreamAvailableCall", testEveryDoneRequestLeadsToAStreamAvailableCall), - ("testCancelAllRunningRequests", testCancelAllRunningRequests), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift index fab866867..14a4d5630 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift @@ -12,17 +12,20 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOHPACK import NIOHTTP1 +import NIOHTTP2 import NIOPosix import NIOSSL import NIOTestUtils import XCTest +@testable import AsyncHTTPClient + class HTTP2ConnectionTests: XCTestCase { func testCreateNewConnectionFailureClosedIO() { let embedded = EmbeddedChannel() @@ -33,13 +36,41 @@ class HTTP2ConnectionTests: XCTestCase { embedded.embeddedEventLoop.run() let logger = Logger(label: "test.http2.connection") - XCTAssertThrowsError(try HTTP2Connection.start( + XCTAssertThrowsError( + try HTTP2Connection.start( + channel: embedded, + connectionID: 0, + delegate: TestHTTP2ConnectionDelegate(), + decompression: .disabled, + maximumConnectionUses: nil, + logger: logger + ).map { _ in }.nonisolated().wait() + ) + } + + func testConnectionToleratesShutdownEventsAfterAlreadyClosed() { + let embedded = EmbeddedChannel() + XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) + + let logger = Logger(label: "test.http2.connection") + let connection = HTTP2Connection( channel: embedded, connectionID: 0, + decompression: .disabled, + maximumConnectionUses: nil, delegate: TestHTTP2ConnectionDelegate(), - configuration: .init(), logger: logger - ).wait()) + ) + let startFuture = connection._start0() + + XCTAssertNoThrow(try embedded.close().wait()) + // to really destroy the channel we need to tick once + embedded.embeddedEventLoop.run() + + XCTAssertThrowsError(try startFuture.wait()) + + // should not crash + connection.sendableView.shutdown() } func testSimpleGetRequest() { @@ -52,12 +83,13 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - ) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") @@ -66,22 +98,23 @@ class HTTP2ConnectionTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to have a request bag at this point") } http2Connection.executeRequest(requestBag) - XCTAssertEqual(delegate.hitStreamClosed, 0) var maybeResponse: HTTPClient.Response? XCTAssertNoThrow(maybeResponse = try requestBag.task.futureResult.wait()) XCTAssertEqual(maybeResponse?.status, .ok) @@ -108,12 +141,14 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - )) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") } @@ -127,15 +162,17 @@ class HTTP2ConnectionTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to have a request bag at this point") } @@ -172,12 +209,13 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - ) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") @@ -189,15 +227,17 @@ class HTTP2ConnectionTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to have a request bag at this point") } @@ -219,9 +259,132 @@ class HTTP2ConnectionTests: XCTestCase { XCTAssertNoThrow(try http2Connection.closeFuture.wait()) } + + func testChildStreamsAreRemovedFromTheOpenChannelListOnceTheRequestIsDone() { + class SucceedPromiseOnRequestHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + let dataArrivedPromise: EventLoopPromise + let triggerResponseFuture: EventLoopFuture + + init(dataArrivedPromise: EventLoopPromise, triggerResponseFuture: EventLoopFuture) { + self.dataArrivedPromise = dataArrivedPromise + self.triggerResponseFuture = triggerResponseFuture + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.dataArrivedPromise.succeed(()) + + self.triggerResponseFuture.hop(to: context.eventLoop).assumeIsolated().whenSuccess { + switch self.unwrapInboundIn(data) { + case .head: + context.write(self.wrapOutboundOut(.head(.init(version: .http2, status: .ok))), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .body, .end: + break + } + } + } + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let serverReceivedRequestPromise = eventLoop.makePromise(of: Void.self) + let triggerResponsePromise = eventLoop.makePromise(of: Void.self) + let httpBin = HTTPBin(.http2(compress: false)) { _ in + SucceedPromiseOnRequestHandler( + dataArrivedPromise: serverReceivedRequestPromise, + triggerResponseFuture: triggerResponsePromise.futureResult + ) + } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let connectionCreator = TestConnectionCreator() + let delegate = TestHTTP2ConnectionDelegate() + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) + guard let http2Connection = maybeHTTP2Connection else { + return XCTFail("Expected to have an HTTP2 connection here.") + } + + var maybeRequest: HTTPClient.Request? + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to have a request bag at this point") + } + + http2Connection.executeRequest(requestBag) + + XCTAssertNoThrow(try serverReceivedRequestPromise.futureResult.wait()) + var channelCount: Int? + XCTAssertNoThrow( + channelCount = try eventLoop.submit { http2Connection.__forTesting_getStreamChannels().count }.wait() + ) + XCTAssertEqual(channelCount, 1) + triggerResponsePromise.succeed(()) + + XCTAssertNoThrow(try requestBag.task.futureResult.wait()) + + // this is racy. for this reason we allow a couple of tries + var retryCount = 0 + let maxRetries = 1000 + while retryCount < maxRetries { + XCTAssertNoThrow( + channelCount = try eventLoop.submit { http2Connection.__forTesting_getStreamChannels().count }.wait() + ) + if channelCount == 0 { + break + } + retryCount += 1 + } + XCTAssertLessThan(retryCount, maxRetries) + } + + func testServerPushIsDisabled() { + let embedded = EmbeddedChannel() + let logger = Logger(label: "test.http2.connection") + let connection = HTTP2Connection( + channel: embedded, + connectionID: 0, + decompression: .disabled, + maximumConnectionUses: nil, + delegate: TestHTTP2ConnectionDelegate(), + logger: logger + ) + _ = connection._start0() + + let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([]))) + XCTAssertNoThrow(try connection.channel.writeAndFlush(settingsFrame).wait()) + + let pushPromiseFrame = HTTP2Frame(streamID: 0, payload: .pushPromise(.init(pushedStreamID: 1, headers: [:]))) + XCTAssertThrowsError(try connection.channel.writeAndFlush(pushPromiseFrame).wait()) { error in + XCTAssertNotNil(error as? NIOHTTP2Errors.PushInViolationOfSetting) + } + } } -class TestConnectionCreator { +final class TestConnectionCreator { enum Error: Swift.Error { case alreadyCreatingAnotherConnection case wantedHTTP2ConnectionButGotHTTP1 @@ -230,12 +393,11 @@ class TestConnectionCreator { enum State { case idle - case waitingForHTTP1Connection(EventLoopPromise) - case waitingForHTTP2Connection(EventLoopPromise) + case waitingForHTTP1Connection(EventLoopPromise) + case waitingForHTTP2Connection(EventLoopPromise) } - private var state: State = .idle - private let lock = Lock() + private let lock = NIOLockedValueBox(.idle) init() {} @@ -245,7 +407,7 @@ class TestConnectionCreator { connectionID: HTTPConnectionPool.Connection.ID = 0, on eventLoop: EventLoop, logger: Logger = .init(label: "test") - ) throws -> HTTP1Connection { + ) throws -> HTTP1Connection.SendableView { let request = try! HTTPClient.Request(url: "https://localhost:\(port)") var tlsConfiguration = TLSConfiguration.makeClientConfiguration() @@ -259,13 +421,13 @@ class TestConnectionCreator { sslContextCache: .init() ) - let promise = try self.lock.withLock { () -> EventLoopPromise in - guard case .idle = self.state else { + let promise = try self.lock.withLockedValue { state in + guard case .idle = state else { throw Error.alreadyCreatingAnotherConnection } - let promise = eventLoop.makePromise(of: HTTP1Connection.self) - self.state = .waitingForHTTP1Connection(promise) + let promise = eventLoop.makePromise(of: HTTP1Connection.SendableView.self) + state = .waitingForHTTP1Connection(promise) return promise } @@ -288,7 +450,7 @@ class TestConnectionCreator { connectionID: HTTPConnectionPool.Connection.ID = 0, on eventLoop: EventLoop, logger: Logger = .init(label: "test") - ) throws -> HTTP2Connection { + ) throws -> HTTP2Connection.SendableView { let request = try! HTTPClient.Request(url: "https://localhost:\(port)") var tlsConfiguration = TLSConfiguration.makeClientConfiguration() @@ -302,13 +464,13 @@ class TestConnectionCreator { sslContextCache: .init() ) - let promise = try self.lock.withLock { () -> EventLoopPromise in - guard case .idle = self.state else { + let promise = try self.lock.withLockedValue { state in + guard case .idle = state else { throw Error.alreadyCreatingAnotherConnection } - let promise = eventLoop.makePromise(of: HTTP2Connection.self) - self.state = .waitingForHTTP2Connection(promise) + let promise = eventLoop.makePromise(of: HTTP2Connection.SendableView.self) + state = .waitingForHTTP2Connection(promise) return promise } @@ -327,7 +489,7 @@ class TestConnectionCreator { } extension TestConnectionCreator: HTTPConnectionRequester { - enum EitherPromiseWrapper { + enum EitherPromiseWrapper: Sendable { case succeed(EventLoopPromise, SucceedType) case fail(EventLoopPromise, Error) @@ -341,37 +503,38 @@ extension TestConnectionCreator: HTTPConnectionRequester { } } - func http1ConnectionCreated(_ connection: HTTP1Connection) { - let wrapper = self.lock.withLock { () -> (EitherPromiseWrapper) in + func http1ConnectionCreated(_ connection: HTTP1Connection.SendableView) { + let wrapper: EitherPromiseWrapper = self.lock + .withLockedValue { state in - switch self.state { - case .waitingForHTTP1Connection(let promise): - return .succeed(promise, connection) + switch state { + case .waitingForHTTP1Connection(let promise): + return .succeed(promise, connection) - case .waitingForHTTP2Connection(let promise): - return .fail(promise, Error.wantedHTTP2ConnectionButGotHTTP1) + case .waitingForHTTP2Connection(let promise): + return .fail(promise, Error.wantedHTTP2ConnectionButGotHTTP1) - case .idle: - preconditionFailure("Invalid state: \(self.state)") + case .idle: + preconditionFailure("Invalid state: \(state)") + } } - } wrapper.complete() } - func http2ConnectionCreated(_ connection: HTTP2Connection, maximumStreams: Int) { - let wrapper = self.lock.withLock { () -> (EitherPromiseWrapper) in + func http2ConnectionCreated(_ connection: HTTP2Connection.SendableView, maximumStreams: Int) { + let wrapper: EitherPromiseWrapper = self.lock + .withLockedValue { state in + switch state { + case .waitingForHTTP1Connection(let promise): + return .fail(promise, Error.wantedHTTP1ConnectionButGotHTTP2) - switch self.state { - case .waitingForHTTP1Connection(let promise): - return .fail(promise, Error.wantedHTTP1ConnectionButGotHTTP2) + case .waitingForHTTP2Connection(let promise): + return .succeed(promise, connection) - case .waitingForHTTP2Connection(let promise): - return .succeed(promise, connection) - - case .idle: - preconditionFailure("Invalid state: \(self.state)") + case .idle: + preconditionFailure("Invalid state: \(state)") + } } - } wrapper.complete() } @@ -390,93 +553,100 @@ extension TestConnectionCreator: HTTPConnectionRequester { } func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Swift.Error) { - let wrapper = self.lock.withLock { () -> (FailPromiseWrapper) in + let wrapper: FailPromiseWrapper = self.lock + .withLockedValue { state in - switch self.state { - case .waitingForHTTP1Connection(let promise): - return .type1(promise) + switch state { + case .waitingForHTTP1Connection(let promise): + return .type1(promise) - case .waitingForHTTP2Connection(let promise): - return .type2(promise) + case .waitingForHTTP2Connection(let promise): + return .type2(promise) - case .idle: - preconditionFailure("Invalid state: \(self.state)") + case .idle: + preconditionFailure("Invalid state: \(state)") + } } - } wrapper.fail(error) } + + func waitingForConnectivity(_: HTTPConnectionPool.Connection.ID, error: Swift.Error) { + preconditionFailure("TODO") + } } -class TestHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { +final class TestHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { var hitStreamClosed: Int { - self.lock.withLock { self._hitStreamClosed } + self.lock.withLockedValue { $0.hitStreamClosed } } var hitGoAwayReceived: Int { - self.lock.withLock { self._hitGoAwayReceived } + self.lock.withLockedValue { $0.hitGoAwayReceived } } var hitConnectionClosed: Int { - self.lock.withLock { self._hitConnectionClosed } + self.lock.withLockedValue { $0.hitConnectionClosed } } var maxStreamSetting: Int { - self.lock.withLock { self._maxStreamSetting } + self.lock.withLockedValue { $0.maxStreamSetting } } - private let lock = Lock() - private var _hitStreamClosed: Int = 0 - private var _hitGoAwayReceived: Int = 0 - private var _hitConnectionClosed: Int = 0 - private var _maxStreamSetting: Int = 100 + private let lock = NIOLockedValueBox(.init()) + private struct Counts { + var hitStreamClosed: Int = 0 + var hitGoAwayReceived: Int = 0 + var hitConnectionClosed: Int = 0 + var maxStreamSetting: Int = 100 + } init() {} - func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) {} + func http2Connection(_: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) {} - func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) { - self.lock.withLockVoid { - self._hitStreamClosed += 1 + func http2ConnectionStreamClosed(_: HTTPConnectionPool.Connection.ID, availableStreams: Int) { + self.lock.withLockedValue { + $0.hitStreamClosed += 1 } } - func http2ConnectionGoAwayReceived(_: HTTP2Connection) { - self.lock.withLockVoid { - self._hitGoAwayReceived += 1 + func http2ConnectionGoAwayReceived(_: HTTPConnectionPool.Connection.ID) { + self.lock.withLockedValue { + $0.hitGoAwayReceived += 1 } } - func http2ConnectionClosed(_: HTTP2Connection) { - self.lock.withLockVoid { - self._hitConnectionClosed += 1 + func http2ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { + self.lock.withLockedValue { + $0.hitConnectionClosed += 1 } } } final class EmptyHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { - func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) { + func http2Connection(_: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) { preconditionFailure("Unimplemented") } - func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) { + func http2ConnectionStreamClosed(_: HTTPConnectionPool.Connection.ID, availableStreams: Int) { preconditionFailure("Unimplemented") } - func http2ConnectionGoAwayReceived(_: HTTP2Connection) { + func http2ConnectionGoAwayReceived(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } - func http2ConnectionClosed(_: HTTP2Connection) { + func http2ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } } final class EmptyHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { - func http1ConnectionReleased(_: HTTP1Connection) { + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } - func http1ConnectionClosed(_: HTTP1Connection) { + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } } diff --git a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests+XCTest.swift deleted file mode 100644 index 1b9558105..000000000 --- a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests+XCTest.swift +++ /dev/null @@ -1,41 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTP2IdleHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTP2IdleHandlerTests { - static var allTests: [(String, (HTTP2IdleHandlerTests) -> () throws -> Void)] { - return [ - ("testReceiveSettingsWithMaxConcurrentStreamSetting", testReceiveSettingsWithMaxConcurrentStreamSetting), - ("testReceiveSettingsWithoutMaxConcurrentStreamSetting", testReceiveSettingsWithoutMaxConcurrentStreamSetting), - ("testEmptySettingsDontOverwriteMaxConcurrentStreamSetting", testEmptySettingsDontOverwriteMaxConcurrentStreamSetting), - ("testOverwriteMaxConcurrentStreamSetting", testOverwriteMaxConcurrentStreamSetting), - ("testGoAwayReceivedBeforeSettings", testGoAwayReceivedBeforeSettings), - ("testGoAwayReceivedAfterSettings", testGoAwayReceivedAfterSettings), - ("testCloseEventBeforeFirstSettings", testCloseEventBeforeFirstSettings), - ("testCloseEventWhileNoOpenStreams", testCloseEventWhileNoOpenStreams), - ("testCloseEventWhileThereAreOpenStreams", testCloseEventWhileThereAreOpenStreams), - ("testGoAwayWhileThereAreOpenStreams", testGoAwayWhileThereAreOpenStreams), - ("testReceiveSettingsAndGoAwayAfterClientSideClose", testReceiveSettingsAndGoAwayAfterClientSideClose), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift index 355969c6a..f2b56daa0 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded import NIOHTTP2 import XCTest +@testable import AsyncHTTPClient + class HTTP2IdleHandlerTests: XCTestCase { func testReceiveSettingsWithMaxConcurrentStreamSetting() { let delegate = MockHTTP2IdleHandlerDelegate() @@ -26,7 +27,10 @@ class HTTP2IdleHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [idleHandler]) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + let settingsFrame = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)])) + ) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) XCTAssertEqual(delegate.maxStreams, 10) @@ -41,7 +45,11 @@ class HTTP2IdleHandlerTests: XCTestCase { let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([]))) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) - XCTAssertEqual(delegate.maxStreams, 100, "Expected to assume 100 maxConcurrentConnection, if no setting was present") + XCTAssertEqual( + delegate.maxStreams, + 100, + "Expected to assume 100 maxConcurrentConnection, if no setting was present" + ) } func testEmptySettingsDontOverwriteMaxConcurrentStreamSetting() { @@ -50,7 +58,10 @@ class HTTP2IdleHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [idleHandler]) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + let settingsFrame = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)])) + ) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) XCTAssertEqual(delegate.maxStreams, 10) @@ -66,12 +77,18 @@ class HTTP2IdleHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [idleHandler]) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + let settingsFrame = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)])) + ) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) XCTAssertEqual(delegate.maxStreams, 10) - let emptySettings = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 20)]))) + let emptySettings = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 20)])) + ) XCTAssertNoThrow(try embedded.writeInbound(emptySettings)) XCTAssertEqual(delegate.maxStreams, 20) } @@ -83,7 +100,10 @@ class HTTP2IdleHandlerTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) let randomStreamID = HTTP2StreamID((0.. () throws -> Void)] { - return [ - ("testProxySOCKS", testProxySOCKS), - ("testProxySOCKSBogusAddress", testProxySOCKSBogusAddress), - ("testProxySOCKSFailureNoServer", testProxySOCKSFailureNoServer), - ("testProxySOCKSFailureInvalidServer", testProxySOCKSFailureInvalidServer), - ("testProxySOCKSMisbehavingServer", testProxySOCKSMisbehavingServer), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift index 5fdc5ac61..d5e1c895b 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift @@ -12,9 +12,11 @@ // //===----------------------------------------------------------------------===// -/* NOT @testable */ import AsyncHTTPClient // Tests that need @testable go into HTTPClientInternalTests.swift +import AsyncHTTPClient // NOT @testable - tests that need @testable go into HTTPClientInternalTests.swift +import InMemoryLogging import Logging import NIOCore +import NIOHTTP1 import NIOPosix import NIOSOCKS import XCTest @@ -26,10 +28,10 @@ class HTTPClientSOCKSTests: XCTestCase { var serverGroup: EventLoopGroup! var defaultHTTPBin: HTTPBin! var defaultClient: HTTPClient! - var backgroundLogStore: CollectEverythingLogHandler.LogStore! + var backgroundLogStore: InMemoryLogHandler! var defaultHTTPBinURLPrefix: String { - return "http://localhost:\(self.defaultHTTPBin.port)/" + "http://localhost:\(self.defaultHTTPBin.port)/" } override func setUp() { @@ -42,13 +44,12 @@ class HTTPClientSOCKSTests: XCTestCase { self.clientGroup = getDefaultEventLoopGroup(numberOfThreads: 1) self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) self.defaultHTTPBin = HTTPBin() - self.backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: self.backgroundLogStore!) - }) - backgroundLogger.logLevel = .trace - self.defaultClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - backgroundActivityLogger: backgroundLogger) + let (backgroundLogStore, backgroundLogger) = InMemoryLogHandler.makeLogger(logLevel: .trace) + self.backgroundLogStore = backgroundLogStore + self.defaultClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + backgroundActivityLogger: backgroundLogger + ) } override func tearDown() { @@ -75,8 +76,12 @@ class HTTPClientSOCKSTests: XCTestCase { func testProxySOCKS() throws { let socksBin = try MockSOCKSServer(expectedURL: "/socks/test", expectedResponse: "it works!") - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .socksServer(host: "localhost", port: socksBin.port))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + proxy: .socksServer(host: "localhost", port: socksBin.port) + ).enableFastFailureModeForTesting() + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -90,8 +95,12 @@ class HTTPClientSOCKSTests: XCTestCase { } func testProxySOCKSBogusAddress() throws { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .socksServer(host: "127.0.."))) + let config = HTTPClient.Configuration(proxy: .socksServer(host: "127.0..")) + .enableFastFailureModeForTesting() + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -102,8 +111,13 @@ class HTTPClientSOCKSTests: XCTestCase { // there is no socks server, so we should fail func testProxySOCKSFailureNoServer() throws { let localHTTPBin = HTTPBin() - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .socksServer(host: "localhost", port: localHTTPBin.port))) + let config = HTTPClient.Configuration(proxy: .socksServer(host: "localhost", port: localHTTPBin.port)) + .enableFastFailureModeForTesting() + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -113,8 +127,13 @@ class HTTPClientSOCKSTests: XCTestCase { // speak to a server that doesn't speak SOCKS func testProxySOCKSFailureInvalidServer() throws { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .socksServer(host: "localhost"))) + let config = HTTPClient.Configuration(proxy: .socksServer(host: "localhost")) + .enableFastFailureModeForTesting() + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } @@ -124,8 +143,13 @@ class HTTPClientSOCKSTests: XCTestCase { // test a handshake failure with a misbehaving server func testProxySOCKSMisbehavingServer() throws { let socksBin = try MockSOCKSServer(expectedURL: "/socks/test", expectedResponse: "it works!", misbehave: true) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .socksServer(host: "localhost", port: socksBin.port))) + let config = HTTPClient.Configuration(proxy: .socksServer(host: "localhost", port: socksBin.port)) + .enableFastFailureModeForTesting() + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) diff --git a/Tests/AsyncHTTPClientTests/HTTPClient+StructuredConcurrencyTests.swift b/Tests/AsyncHTTPClientTests/HTTPClient+StructuredConcurrencyTests.swift new file mode 100644 index 000000000..a7cc1f454 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClient+StructuredConcurrencyTests.swift @@ -0,0 +1,101 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import NIO +import NIOFoundationCompat +import NIOHTTP1 +import XCTest + +final class HTTPClientStructuredConcurrencyTests: XCTestCase { + func testDoNothingWorks() async throws { + let actual = try await HTTPClient.withHTTPClient { httpClient in + "OK" + } + XCTAssertEqual("OK", actual) + } + + func testShuttingDownTheClientInBodyLeadsToError() async { + do { + let actual = try await HTTPClient.withHTTPClient { httpClient in + try await httpClient.shutdown() + return "OK" + } + XCTFail("Expected error, got \(actual)") + } catch let error as HTTPClientError where error == .alreadyShutdown { + // OK + } catch { + XCTFail("unexpected error: \(error)") + } + } + + func testBasicRequest() async throws { + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let actualBytes = try await HTTPClient.withHTTPClient { httpClient in + let response = try await httpClient.get(url: httpBin.baseURL).get() + XCTAssertEqual(response.status, .ok) + return response.body ?? ByteBuffer(string: "n/a") + } + let actual = try JSONDecoder().decode(RequestInfo.self, from: actualBytes) + + XCTAssertGreaterThanOrEqual(actual.requestNumber, 0) + XCTAssertGreaterThanOrEqual(actual.connectionNumber, 0) + } + + func testClientIsShutDownAfterReturn() async throws { + let leakedClient = try await HTTPClient.withHTTPClient { httpClient in + httpClient + } + do { + try await leakedClient.shutdown() + XCTFail("unexpected, shutdown should have failed") + } catch let error as HTTPClientError where error == .alreadyShutdown { + // OK + } catch { + XCTFail("unexpected error: \(error)") + } + } + + func testClientIsShutDownOnThrowAlso() async throws { + struct TestError: Error { + var httpClient: HTTPClient + } + + let leakedClient: HTTPClient + do { + try await HTTPClient.withHTTPClient { httpClient in + throw TestError(httpClient: httpClient) + } + XCTFail("unexpected, shutdown should have failed") + return + } catch let error as TestError { + // OK + leakedClient = error.httpClient + } catch { + XCTFail("unexpected error: \(error)") + return + } + + do { + try await leakedClient.shutdown() + XCTFail("unexpected, shutdown should have failed") + } catch let error as HTTPClientError where error == .alreadyShutdown { + // OK + } catch { + XCTFail("unexpected error: \(error)") + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientBase.swift b/Tests/AsyncHTTPClientTests/HTTPClientBase.swift new file mode 100644 index 000000000..90ab12fe5 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClientBase.swift @@ -0,0 +1,88 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import InMemoryLogging +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +class XCTestCaseHTTPClientTestsBaseClass: XCTestCase { + typealias Request = HTTPClient.Request + + var clientGroup: EventLoopGroup! + var serverGroup: EventLoopGroup! + var defaultHTTPBin: HTTPBin! + var defaultClient: HTTPClient! + var backgroundLogStore: InMemoryLogHandler! + + var defaultHTTPBinURLPrefix: String { + self.defaultHTTPBin.baseURL + } + + override func setUp() { + XCTAssertNil(self.clientGroup) + XCTAssertNil(self.serverGroup) + XCTAssertNil(self.defaultHTTPBin) + XCTAssertNil(self.defaultClient) + XCTAssertNil(self.backgroundLogStore) + + self.clientGroup = getDefaultEventLoopGroup(numberOfThreads: 1) + self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + self.defaultHTTPBin = HTTPBin() + let (backgroundLogStore, backgroundLogger) = InMemoryLogHandler.makeLogger(logLevel: .trace) + self.backgroundLogStore = backgroundLogStore + + self.defaultClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration().enableFastFailureModeForTesting(), + backgroundActivityLogger: backgroundLogger + ) + } + + override func tearDown() { + if let defaultClient = self.defaultClient { + XCTAssertNoThrow(try defaultClient.syncShutdown()) + self.defaultClient = nil + } + + XCTAssertNotNil(self.defaultHTTPBin) + XCTAssertNoThrow(try self.defaultHTTPBin.shutdown()) + self.defaultHTTPBin = nil + + XCTAssertNotNil(self.clientGroup) + XCTAssertNoThrow(try self.clientGroup.syncShutdownGracefully()) + self.clientGroup = nil + + XCTAssertNotNil(self.serverGroup) + XCTAssertNoThrow(try self.serverGroup.syncShutdownGracefully()) + self.serverGroup = nil + + XCTAssertNotNil(self.backgroundLogStore) + self.backgroundLogStore = nil + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientCookieTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientCookieTests+XCTest.swift deleted file mode 100644 index 7ecf54d4d..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientCookieTests+XCTest.swift +++ /dev/null @@ -1,44 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientCookieTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientCookieTests { - static var allTests: [(String, (HTTPClientCookieTests) -> () throws -> Void)] { - return [ - ("testCookie", testCookie), - ("testEmptyValueCookie", testEmptyValueCookie), - ("testCookieDefaults", testCookieDefaults), - ("testCookieInit", testCookieInit), - ("testMalformedCookies", testMalformedCookies), - ("testExpires", testExpires), - ("testMaxAge", testMaxAge), - ("testDomain", testDomain), - ("testPath", testPath), - ("testSecure", testSecure), - ("testHttpOnly", testHttpOnly), - ("testCookieExpiresDateParsing", testCookieExpiresDateParsing), - ("testQuotedCookies", testQuotedCookies), - ("testCookieExpiresDateParsingWithNonEnglishLocale", testCookieExpiresDateParsingWithNonEnglishLocale), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientCookieTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientCookieTests.swift index 8b4c9adf6..fa9abb9d8 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientCookieTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientCookieTests.swift @@ -19,7 +19,8 @@ import XCTest class HTTPClientCookieTests: XCTestCase { func testCookie() { - let v = "key=value; PaTh=/path; DoMaIn=EXampLE.CoM; eXpIRes=Wed, 21 Oct 2015 07:28:00 GMT; max-AGE=42; seCURE; HTTPOnly" + let v = + "key=value; PaTh=/path; DoMaIn=EXampLE.CoM; eXpIRes=Wed, 21 Oct 2015 07:28:00 GMT; max-AGE=42; seCURE; HTTPOnly" guard let c = HTTPClient.Cookie(header: v, defaultDomain: "exAMPle.cOm") else { XCTFail("Failed to parse cookie") return @@ -67,7 +68,16 @@ class HTTPClientCookieTests: XCTestCase { } func testCookieInit() { - let c = HTTPClient.Cookie(name: "key", value: "value", path: "/path", domain: "example.com", expires: Date(timeIntervalSince1970: 1_445_412_480), maxAge: 42, httpOnly: true, secure: true) + let c = HTTPClient.Cookie( + name: "key", + value: "value", + path: "/path", + domain: "example.com", + expires: Date(timeIntervalSince1970: 1_445_412_480), + maxAge: 42, + httpOnly: true, + secure: true + ) XCTAssertEqual("key", c.name) XCTAssertEqual("value", c.value) XCTAssertEqual("/path", c.path) @@ -118,17 +128,26 @@ class HTTPClientCookieTests: XCTestCase { XCTAssertNil(c?.expires) // Later values override earlier values, except if they are ignored. - c = HTTPClient.Cookie(header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires=04/01/2022", defaultDomain: "example.com") + c = HTTPClient.Cookie( + header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires=04/01/2022", + defaultDomain: "example.com" + ) XCTAssertEqual("key", c?.name) XCTAssertEqual("value", c?.value) XCTAssertEqual(Date(timeIntervalSince1970: 784_111_777), c?.expires) - c = HTTPClient.Cookie(header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires=", defaultDomain: "example.com") + c = HTTPClient.Cookie( + header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires=", + defaultDomain: "example.com" + ) XCTAssertEqual("key", c?.name) XCTAssertEqual("value", c?.value) XCTAssertEqual(Date(timeIntervalSince1970: 784_111_777), c?.expires) - c = HTTPClient.Cookie(header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires", defaultDomain: "example.com") + c = HTTPClient.Cookie( + header: "key=value; expires=Sunday, 06-Nov-94 08:49:37 GMT; expires", + defaultDomain: "example.com" + ) XCTAssertEqual("key", c?.name) XCTAssertEqual("value", c?.value) XCTAssertEqual(Date(timeIntervalSince1970: 784_111_777), c?.expires) @@ -467,11 +486,17 @@ class HTTPClientCookieTests: XCTestCase { try XCTSkipIf(localeCheck.tm_mon != 1, "Unable to set locale") // Cookie parsing should be independent of C locale. - var c = HTTPClient.Cookie(header: "key=value; eXpIRes=Sunday, 06-Nov-94 08:49:37 GMT;", defaultDomain: "example.org") + var c = HTTPClient.Cookie( + header: "key=value; eXpIRes=Sunday, 06-Nov-94 08:49:37 GMT;", + defaultDomain: "example.org" + ) XCTAssertEqual(Date(timeIntervalSince1970: 784_111_777), c?.expires) c = HTTPClient.Cookie(header: "key=value; eXpIRes=Sun Nov 6 08:49:37 1994;", defaultDomain: "example.org")! XCTAssertEqual(Date(timeIntervalSince1970: 784_111_777), c?.expires) - c = HTTPClient.Cookie(header: "key=value; eXpIRes=Sonntag, 06-Nov-94 08:49:37 GMT;", defaultDomain: "example.org")! + c = HTTPClient.Cookie( + header: "key=value; eXpIRes=Sonntag, 06-Nov-94 08:49:37 GMT;", + defaultDomain: "example.org" + )! XCTAssertNil(c?.expires) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests+XCTest.swift deleted file mode 100644 index 63d7f85e2..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests+XCTest.swift +++ /dev/null @@ -1,32 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientInformationalResponsesTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientReproTests { - static var allTests: [(String, (HTTPClientReproTests) -> () throws -> Void)] { - return [ - ("testServerSends100ContinueFirst", testServerSends100ContinueFirst), - ("testServerSendsSwitchingProtocols", testServerSendsSwitchingProtocols), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests.swift index f57d5fd10..5c41a6adb 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInformationalResponsesTests.swift @@ -27,7 +27,10 @@ final class HTTPClientReproTests: XCTestCase { func channelRead(context: ChannelHandlerContext, data: NIOAny) { switch self.unwrapInboundIn(data) { case .head: - context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .continue))), promise: nil) + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .continue))), + promise: nil + ) case .body: break case .end: @@ -37,7 +40,7 @@ final class HTTPClientReproTests: XCTestCase { } } - let client = HTTPClient(eventLoopGroupProvider: .createNew) + let client = HTTPClient(eventLoopGroupProvider: .singleton) defer { XCTAssertNoThrow(try client.syncShutdown()) } let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in @@ -47,14 +50,16 @@ final class HTTPClientReproTests: XCTestCase { let body = #"{"foo": "bar"}"# var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "http://localhost:\(httpBin.port)/", - method: .POST, - headers: [ - "Content-Type": "application/json", - ], - body: .string(body) - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost:\(httpBin.port)/", + method: .POST, + headers: [ + "Content-Type": "application/json" + ], + body: .string(body) + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request here") } var logger = Logger(label: "test") @@ -73,10 +78,14 @@ final class HTTPClientReproTests: XCTestCase { func channelRead(context: ChannelHandlerContext, data: NIOAny) { switch self.unwrapInboundIn(data) { case .head: - let head = HTTPResponseHead(version: .http1_1, status: .switchingProtocols, headers: [ - "Connection": "Upgrade", - "Upgrade": "Websocket", - ]) + let head = HTTPResponseHead( + version: .http1_1, + status: .switchingProtocols, + headers: [ + "Connection": "Upgrade", + "Upgrade": "Websocket", + ] + ) let body = context.channel.allocator.buffer(string: "foo bar") context.write(self.wrapOutboundOut(.head(head)), promise: nil) @@ -91,7 +100,7 @@ final class HTTPClientReproTests: XCTestCase { } } - let client = HTTPClient(eventLoopGroupProvider: .createNew) + let client = HTTPClient(eventLoopGroupProvider: .singleton) defer { XCTAssertNoThrow(try client.syncShutdown()) } let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in @@ -101,14 +110,16 @@ final class HTTPClientReproTests: XCTestCase { let body = #"{"foo": "bar"}"# var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "http://localhost:\(httpBin.port)/", - method: .POST, - headers: [ - "Content-Type": "application/json", - ], - body: .string(body) - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost:\(httpBin.port)/", + method: .POST, + headers: [ + "Content-Type": "application/json" + ], + body: .string(body) + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request here") } var logger = Logger(label: "test") diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift deleted file mode 100644 index 3be2c79a6..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ /dev/null @@ -1,41 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientInternalTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientInternalTests { - static var allTests: [(String, (HTTPClientInternalTests) -> () throws -> Void)] { - return [ - ("testProxyStreaming", testProxyStreaming), - ("testProxyStreamingFailure", testProxyStreamingFailure), - ("testRequestURITrailingSlash", testRequestURITrailingSlash), - ("testChannelAndDelegateOnDifferentEventLoops", testChannelAndDelegateOnDifferentEventLoops), - ("testResponseFutureIsOnCorrectEL", testResponseFutureIsOnCorrectEL), - ("testUncleanCloseThrows", testUncleanCloseThrows), - ("testUploadStreamingIsCalledOnTaskEL", testUploadStreamingIsCalledOnTaskEL), - ("testTaskPromiseBoundToEL", testTaskPromiseBoundToEL), - ("testConnectErrorCalloutOnCorrectEL", testConnectErrorCalloutOnCorrectEL), - ("testInternalRequestURI", testInternalRequestURI), - ("testHasSuffix", testHasSuffix), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index eb8d523bb..634efc14c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -12,15 +12,17 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOFoundationCompat import NIOHTTP1 import NIOPosix import NIOTestUtils import XCTest +@testable import AsyncHTTPClient + class HTTPClientInternalTests: XCTestCase { typealias Request = HTTPClient.Request typealias Task = HTTPClient.Task @@ -52,7 +54,7 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try httpBin.shutdown()) } - let body: HTTPClient.Body = .stream(length: 50) { writer in + let body: HTTPClient.Body = .stream(contentLength: 50) { writer in do { var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/1") request.headers.add(name: "Accept", value: "text/event-stream") @@ -81,13 +83,13 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try httpBin.shutdown()) } - var body: HTTPClient.Body = .stream(length: 50) { _ in + var body: HTTPClient.Body = .stream(contentLength: 50) { _ in httpClient.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidProxyResponse) } XCTAssertThrowsError(try httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait()) - body = .stream(length: 50) { _ in + body = .stream(contentLength: 50) { _ in do { var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/1") request.headers.add(name: "Accept", value: "text/event-stream") @@ -142,11 +144,30 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertEqual(request12.url.uri, "/some%2Fpathsegment1/pathsegment2") } + func testURIOfRelativeURLRequest() throws { + let requestNoLeadingSlash = try Request( + url: URL( + string: "percent%2Fencoded/hello", + relativeTo: URL(string: "http://127.0.0.1")! + )! + ) + + let requestWithLeadingSlash = try Request( + url: URL( + string: "/percent%2Fencoded/hello", + relativeTo: URL(string: "http://127.0.0.1")! + )! + ) + + XCTAssertEqual(requestNoLeadingSlash.url.uri, "/percent%2Fencoded/hello") + XCTAssertEqual(requestWithLeadingSlash.url.uri, "/percent%2Fencoded/hello") + } + func testChannelAndDelegateOnDifferentEventLoops() throws { - class Delegate: HTTPClientResponseDelegate { + final class Delegate: HTTPClientResponseDelegate { typealias Response = ([Message], [Message]) - enum Message { + enum Message: Sendable { case head(HTTPResponseHead) case bodyPart(ByteBuffer) case sentRequestHead(HTTPRequestHead) @@ -155,53 +176,72 @@ class HTTPClientInternalTests: XCTestCase { case error(Error) } - var receivedMessages: [Message] = [] - var sentMessages: [Message] = [] + private struct Messages: Sendable { + var received: [Message] = [] + var sent: [Message] = [] + } + + private let messages: NIOLoopBoundBox + + var receivedMessages: [Message] { + get { + self.messages.value.received + } + set { + self.messages.value.received = newValue + } + } + var sentMessages: [Message] { + get { + self.messages.value.sent + } + set { + self.messages.value.sent = newValue + } + } private let eventLoop: EventLoop private let randoEL: EventLoop init(expectedEventLoop: EventLoop, randomOtherEventLoop: EventLoop) { self.eventLoop = expectedEventLoop self.randoEL = randomOtherEventLoop + self.messages = .makeBoxSendingValue(Messages(), eventLoop: expectedEventLoop) } func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { - self.eventLoop.assertInEventLoop() self.sentMessages.append(.sentRequestHead(head)) } func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { - self.eventLoop.assertInEventLoop() self.sentMessages.append(.sentRequestPart(part)) } func didSendRequest(task: HTTPClient.Task) { - self.eventLoop.assertInEventLoop() self.sentMessages.append(.sentRequest) } func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.eventLoop.assertInEventLoop() self.receivedMessages.append(.error(error)) } - public func didReceiveHead(task: HTTPClient.Task, - _ head: HTTPResponseHead) -> EventLoopFuture { - self.eventLoop.assertInEventLoop() + public func didReceiveHead( + task: HTTPClient.Task, + _ head: HTTPResponseHead + ) -> EventLoopFuture { self.receivedMessages.append(.head(head)) return self.randoEL.makeSucceededFuture(()) } - func didReceiveBodyPart(task: HTTPClient.Task, - _ buffer: ByteBuffer) -> EventLoopFuture { - self.eventLoop.assertInEventLoop() + func didReceiveBodyPart( + task: HTTPClient.Task, + _ buffer: ByteBuffer + ) -> EventLoopFuture { self.receivedMessages.append(.bodyPart(buffer)) return self.randoEL.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Response { - self.eventLoop.assertInEventLoop() - return (self.receivedMessages, self.sentMessages) + (self.receivedMessages, self.sentMessages) } } @@ -223,7 +263,7 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } - let body: HTTPClient.Body = .stream(length: 8) { writer in + let body: HTTPClient.Body = .stream(contentLength: 8) { writer in let buffer = ByteBuffer(string: "1234") return writer.write(.byteBuffer(buffer)).flatMap { let buffer = ByteBuffer(string: "4321") @@ -231,22 +271,38 @@ class HTTPClientInternalTests: XCTestCase { } } - let request = try Request(url: "http://127.0.0.1:\(server.serverPort)/custom", - body: body) + let request = try Request( + url: "http://127.0.0.1:\(server.serverPort)/custom", + body: body + ) let delegate = Delegate(expectedEventLoop: delegateEL, randomOtherEventLoop: randoEL) - let future = httpClient.execute(request: request, - delegate: delegate, - eventLoop: .init(.testOnly_exact(channelOn: channelEL, - delegateOn: delegateEL))).futureResult - - XCTAssertNoThrow(try server.readInbound()) // .head - XCTAssertNoThrow(try server.readInbound()) // .body - XCTAssertNoThrow(try server.readInbound()) // .end + let future = httpClient.execute( + request: request, + delegate: delegate, + eventLoop: .init( + .testOnly_exact( + channelOn: channelEL, + delegateOn: delegateEL + ) + ) + ).futureResult + + XCTAssertNoThrow(try server.readInbound()) // .head + XCTAssertNoThrow(try server.readInbound()) // .body + XCTAssertNoThrow(try server.readInbound()) // .end // Send 3 parts, but only one should be received until the future is complete - XCTAssertNoThrow(try server.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), - status: .ok, - headers: HTTPHeaders([("Transfer-Encoding", "chunked")]))))) + XCTAssertNoThrow( + try server.writeOutbound( + .head( + .init( + version: .init(major: 1, minor: 1), + status: .ok, + headers: HTTPHeaders([("Transfer-Encoding", "chunked")]) + ) + ) + ) + ) let buffer = ByteBuffer(string: "1234") XCTAssertNoThrow(try server.writeOutbound(.body(.byteBuffer(buffer)))) XCTAssertNoThrow(try server.writeOutbound(.end(nil))) @@ -278,7 +334,7 @@ class HTTPClientInternalTests: XCTestCase { switch sentMessages.dropFirst(3).first { case .some(.sentRequest): - () // OK + () // OK default: XCTFail("wrong message") } @@ -316,7 +372,10 @@ class HTTPClientInternalTests: XCTestCase { let el = group.next() let req1 = client.execute(request: request, eventLoop: .delegate(on: el)) let req2 = client.execute(request: request, eventLoop: .delegateAndChannel(on: el)) - let req3 = client.execute(request: request, eventLoop: .init(.testOnly_exact(channelOn: el, delegateOn: el))) + let req3 = client.execute( + request: request, + eventLoop: .init(.testOnly_exact(channelOn: el, delegateOn: el)) + ) XCTAssert(req1.eventLoop === el) XCTAssert(req2.eventLoop === el) XCTAssert(req3.eventLoop === el) @@ -335,8 +394,8 @@ class HTTPClientInternalTests: XCTestCase { _ = httpClient.get(url: "http://localhost:\(server.serverPort)/wait") - XCTAssertNoThrow(try server.readInbound()) // .head - XCTAssertNoThrow(try server.readInbound()) // .end + XCTAssertNoThrow(try server.readInbound()) // .head + XCTAssertNoThrow(try server.readInbound()) // .end do { try httpClient.syncShutdown(requiresCleanClose: true) @@ -366,7 +425,7 @@ class HTTPClientInternalTests: XCTestCase { let el2 = group.next() XCTAssert(el1 !== el2) - let body: HTTPClient.Body = .stream(length: 8) { writer in + let body: HTTPClient.Body = .stream(contentLength: 8) { writer in XCTAssert(el1.inEventLoop) let buffer = ByteBuffer(string: "1234") return writer.write(.byteBuffer(buffer)).flatMap { @@ -376,10 +435,16 @@ class HTTPClientInternalTests: XCTestCase { } } let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/post", method: .POST, body: body) - let response = httpClient.execute(request: request, - delegate: ResponseAccumulator(request: request), - eventLoop: HTTPClient.EventLoopPreference(.testOnly_exact(channelOn: el2, - delegateOn: el1))) + let response = httpClient.execute( + request: request, + delegate: ResponseAccumulator(request: request), + eventLoop: HTTPClient.EventLoopPreference( + .testOnly_exact( + channelOn: el2, + delegateOn: el1 + ) + ) + ) XCTAssert(el1 === response.eventLoop) XCTAssertNoThrow(try response.wait()) } @@ -400,17 +465,25 @@ class HTTPClientInternalTests: XCTestCase { let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)//get") let delegate = ResponseAccumulator(request: request) - let task = client.execute(request: request, delegate: delegate, eventLoop: .init(.testOnly_exact(channelOn: el1, delegateOn: el2))) + let task = client.execute( + request: request, + delegate: delegate, + eventLoop: .init(.testOnly_exact(channelOn: el1, delegateOn: el2)) + ) XCTAssertTrue(task.futureResult.eventLoop === el2) XCTAssertNoThrow(try task.wait()) } func testConnectErrorCalloutOnCorrectEL() throws { - class TestDelegate: HTTPClientResponseDelegate { + final class TestDelegate: HTTPClientResponseDelegate { typealias Response = Void let expectedEL: EventLoop - var receivedError: Bool = false + let _receivedError = NIOLockedValueBox(false) + + var receivedError: Bool { + self._receivedError.withLockedValue { $0 } + } init(expectedEL: EventLoop) { self.expectedEL = expectedEL @@ -419,7 +492,7 @@ class HTTPClientInternalTests: XCTestCase { func didFinishRequest(task: HTTPClient.Task) throws {} func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.receivedError = true + self._receivedError.withLockedValue { $0 = true } XCTAssertTrue(self.expectedEL.inEventLoop) } } @@ -429,7 +502,9 @@ class HTTPClientInternalTests: XCTestCase { let el2 = elg.next() let httpBin = HTTPBin(.refuse) - let client = HTTPClient(eventLoopGroupProvider: .shared(elg)) + let config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + let client = HTTPClient(eventLoopGroupProvider: .shared(elg), configuration: config) defer { XCTAssertNoThrow(try client.syncShutdown()) @@ -439,7 +514,11 @@ class HTTPClientInternalTests: XCTestCase { let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get") let delegate = TestDelegate(expectedEL: el1) XCTAssertNoThrow(try httpBin.shutdown()) - let task = client.execute(request: request, delegate: delegate, eventLoop: .init(.testOnly_exact(channelOn: el2, delegateOn: el1))) + let task = client.execute( + request: request, + delegate: delegate, + eventLoop: .init(.testOnly_exact(channelOn: el2, delegateOn: el1)) + ) XCTAssertThrowsError(try task.wait()) XCTAssertTrue(delegate.receivedError) } @@ -472,10 +551,13 @@ class HTTPClientInternalTests: XCTestCase { let request6 = try Request(url: "https://127.0.0.1") XCTAssertEqual(request6.deconstructedURL.scheme, .https) - XCTAssertEqual(request6.deconstructedURL.connectionTarget, .ipAddress( - serialization: "127.0.0.1", - address: try! SocketAddress(ipAddress: "127.0.0.1", port: 443) - )) + XCTAssertEqual( + request6.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "127.0.0.1", + address: try! SocketAddress(ipAddress: "127.0.0.1", port: 443) + ) + ) XCTAssertEqual(request6.deconstructedURL.uri, "/") let request7 = try Request(url: "https://0x7F.1:9999") @@ -485,18 +567,24 @@ class HTTPClientInternalTests: XCTestCase { let request8 = try Request(url: "http://[::1]") XCTAssertEqual(request8.deconstructedURL.scheme, .http) - XCTAssertEqual(request8.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[::1]", - address: try! SocketAddress(ipAddress: "::1", port: 80) - )) + XCTAssertEqual( + request8.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[::1]", + address: try! SocketAddress(ipAddress: "::1", port: 80) + ) + ) XCTAssertEqual(request8.deconstructedURL.uri, "/") let request9 = try Request(url: "http://[763e:61d9::6ACA:3100:6274]:4242/foo/bar?baz") XCTAssertEqual(request9.deconstructedURL.scheme, .http) - XCTAssertEqual(request9.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[763e:61d9::6ACA:3100:6274]", - address: try! SocketAddress(ipAddress: "763e:61d9::6aca:3100:6274", port: 4242) - )) + XCTAssertEqual( + request9.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[763e:61d9::6ACA:3100:6274]", + address: try! SocketAddress(ipAddress: "763e:61d9::6aca:3100:6274", port: 4242) + ) + ) XCTAssertEqual(request9.deconstructedURL.uri, "/foo/bar?baz") // Some systems have quirks in their implementations of 'ntop' which cause them to write @@ -505,18 +593,24 @@ class HTTPClientInternalTests: XCTestCase { // so the serialization must be kept verbatim as it was given in the request. let request10 = try Request(url: "http://[::c0a8:1]:4242/foo/bar?baz") XCTAssertEqual(request10.deconstructedURL.scheme, .http) - XCTAssertEqual(request10.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[::c0a8:1]", - address: try! SocketAddress(ipAddress: "::c0a8:1", port: 4242) - )) + XCTAssertEqual( + request10.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[::c0a8:1]", + address: try! SocketAddress(ipAddress: "::c0a8:1", port: 4242) + ) + ) XCTAssertEqual(request10.deconstructedURL.uri, "/foo/bar?baz") let request11 = try Request(url: "http://[::192.168.0.1]:4242/foo/bar?baz") XCTAssertEqual(request11.deconstructedURL.scheme, .http) - XCTAssertEqual(request11.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[::192.168.0.1]", - address: try! SocketAddress(ipAddress: "::192.168.0.1", port: 4242) - )) + XCTAssertEqual( + request11.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[::192.168.0.1]", + address: try! SocketAddress(ipAddress: "::192.168.0.1", port: 4242) + ) + ) XCTAssertEqual(request11.deconstructedURL.uri, "/foo/bar?baz") } @@ -545,11 +639,56 @@ class HTTPClientInternalTests: XCTestCase { } // Empty collection. do { - let elements: Array = [] + let elements: [Int] = [] XCTAssertTrue(elements.hasSuffix([])) XCTAssertFalse(elements.hasSuffix([0])) XCTAssertFalse(elements.hasSuffix([42])) XCTAssertFalse(elements.hasSuffix([0, 0, 0])) } } + + /// test to verify that we actually share the same thread pool across all ``FileDownloadDelegate``s for a given ``HTTPClient`` + func testSharedThreadPoolIsIdenticalForAllDelegates() throws { + let httpBin = HTTPBin() + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/content-length") + request.headers.add(name: "Accept", value: "text/event-stream") + + let filePaths = (0..<10).map { _ in + TemporaryFileHelpers.makeTemporaryFilePath() + } + defer { + for filePath in filePaths { + TemporaryFileHelpers.removeTemporaryFile(at: filePath) + } + } + let delegates = try filePaths.map { + try FileDownloadDelegate(path: $0) + } + + let resultFutures = delegates.map { delegate in + httpClient.execute( + request: request, + delegate: delegate + ).futureResult + } + _ = try EventLoopFuture.whenAllSucceed(resultFutures, on: self.clientGroup.next()).wait() + + let threadPools = delegates.map { $0._fileIOThreadPool } + let firstThreadPool = threadPools.first ?? nil + XCTAssert(threadPools.dropFirst().allSatisfy { $0 === firstThreadPool }) + } +} + +extension HTTPClient.Configuration { + func enableFastFailureModeForTesting() -> Self { + var copy = self + copy.networkFrameworkWaitForConnectivity = false + copy.connectionPool.retryConnectionEstablishment = false + return copy + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests+XCTest.swift deleted file mode 100644 index cc33f6aee..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests+XCTest.swift +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientNIOTSTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientNIOTSTests { - static var allTests: [(String, (HTTPClientNIOTSTests) -> () throws -> Void)] { - return [ - ("testCorrectEventLoopGroup", testCorrectEventLoopGroup), - ("testTLSFailError", testTLSFailError), - ("testConnectionFailError", testConnectionFailError), - ("testTLSVersionError", testTLSVersionError), - ("testTrustRootCertificateLoadFail", testTrustRootCertificateLoadFail), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift index 172ee89ba..4c2d24dc4 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift @@ -12,16 +12,19 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient -#if canImport(Network) -import Network -#endif +import NIOConcurrencyHelpers import NIOCore import NIOPosix import NIOSSL import NIOTransportServices import XCTest +@testable import AsyncHTTPClient + +#if canImport(Network) +import Network +#endif + class HTTPClientNIOTSTests: XCTestCase { var clientGroup: EventLoopGroup! @@ -37,7 +40,7 @@ class HTTPClientNIOTSTests: XCTestCase { } func testCorrectEventLoopGroup() { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } @@ -54,7 +57,12 @@ class HTTPClientNIOTSTests: XCTestCase { guard isTestingNIOTS() else { return } let httpBin = HTTPBin(.http1_1(ssl: true)) - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) + let config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) @@ -65,8 +73,10 @@ class HTTPClientNIOTSTests: XCTestCase { _ = try httpClient.get(url: "https://localhost:\(httpBin.port)/get").wait() XCTFail("This should have failed") } catch let error as HTTPClient.NWTLSError { - XCTAssert(error.status == errSSLHandshakeFail || error.status == errSSLBadCert, - "unexpected NWTLSError with status \(error.status)") + XCTAssert( + error.status == errSSLHandshakeFail || error.status == errSSLBadCert, + "unexpected NWTLSError with status \(error.status)" + ) } catch { XCTFail("Error should have been NWTLSError not \(type(of: error))") } @@ -75,12 +85,44 @@ class HTTPClientNIOTSTests: XCTestCase { #endif } + func testConnectionFailsFastError() { + guard isTestingNIOTS() else { return } + #if canImport(Network) + let httpBin = HTTPBin(.http1_1(ssl: false)) + let config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) + + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + } + + let port = httpBin.port + XCTAssertNoThrow(try httpBin.shutdown()) + + XCTAssertThrowsError(try httpClient.get(url: "http://localhost:\(port)/get").wait()) { + XCTAssertTrue($0 is NWError) + } + #endif + } + func testConnectionFailError() { guard isTestingNIOTS() else { return } - let httpBin = HTTPBin(.http1_1(ssl: true)) - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(timeout: .init(connect: .milliseconds(100), - read: .milliseconds(100)))) + #if canImport(Network) + let httpBin = HTTPBin(.http1_1(ssl: false)) + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + timeout: .init( + connect: .milliseconds(100), + read: .milliseconds(100) + ) + ) + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) @@ -89,9 +131,16 @@ class HTTPClientNIOTSTests: XCTestCase { let port = httpBin.port XCTAssertNoThrow(try httpBin.shutdown()) - XCTAssertThrowsError(try httpClient.get(url: "https://localhost:\(port)/get").wait()) { - XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) + XCTAssertThrowsError(try httpClient.get(url: "http://localhost:\(port)/get").wait()) { + if let httpClientError = $0 as? HTTPClientError { + XCTAssertEqual(httpClientError, .connectTimeout) + } else if let posixError = $0 as? HTTPClient.NWPOSIXError { + XCTAssertEqual(posixError.errorCode, .ECONNREFUSED) + } else { + XCTFail("unexpected error \($0)") + } } + #endif } func testTLSVersionError() { @@ -102,9 +151,12 @@ class HTTPClientNIOTSTests: XCTestCase { tlsConfig.certificateVerification = .none tlsConfig.minimumTLSVersion = .tlsv11 tlsConfig.maximumTLSVersion = .tlsv1 + + let clientConfig = HTTPClient.Configuration(tlsConfiguration: tlsConfig) + .enableFastFailureModeForTesting() let httpClient = HTTPClient( eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(tlsConfiguration: tlsConfig) + configuration: clientConfig ) defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) @@ -124,7 +176,7 @@ class HTTPClientNIOTSTests: XCTestCase { var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.trustRoots = .file("not/a/certificate") - XCTAssertThrowsError(try tlsConfig.getNWProtocolTLSOptions()) { error in + XCTAssertThrowsError(try tlsConfig.getNWProtocolTLSOptions(serverNameIndicatorOverride: nil)) { error in switch error { case let error as NIOSSL.NIOSSLError where error == .failedToLoadCertificate: break diff --git a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientRequestTests+XCTest.swift deleted file mode 100644 index 30d93f7de..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests+XCTest.swift +++ /dev/null @@ -1,43 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientRequestTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientRequestTests { - static var allTests: [(String, (HTTPClientRequestTests) -> () throws -> Void)] { - return [ - ("testCustomHeadersAreRespected", testCustomHeadersAreRespected), - ("testUnixScheme", testUnixScheme), - ("testHTTPUnixScheme", testHTTPUnixScheme), - ("testHTTPSUnixScheme", testHTTPSUnixScheme), - ("testGetWithoutBody", testGetWithoutBody), - ("testPostWithoutBody", testPostWithoutBody), - ("testPostWithEmptyByteBuffer", testPostWithEmptyByteBuffer), - ("testPostWithByteBuffer", testPostWithByteBuffer), - ("testPostWithSequenceOfUnknownLength", testPostWithSequenceOfUnknownLength), - ("testPostWithSequenceWithFixedLength", testPostWithSequenceWithFixedLength), - ("testPostWithRandomAccessCollection", testPostWithRandomAccessCollection), - ("testPostWithAsyncSequenceOfUnknownLength", testPostWithAsyncSequenceOfUnknownLength), - ("testPostWithAsyncSequenceWithKnownLength", testPostWithAsyncSequenceWithKnownLength), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift index 1ebe7e939..54467aab7 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift @@ -12,58 +12,75 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient +import Algorithms +import NIOConcurrencyHelpers import NIOCore +import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) class HTTPClientRequestTests: XCTestCase { - #if compiler(>=5.5.2) && canImport(_Concurrency) - @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) private typealias Request = HTTPClientRequest - @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) private typealias PreparedRequest = HTTPClientRequest.Prepared - #endif func testCustomHeadersAreRespected() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "https://example.com/get") request.headers = [ - "custom-header": "custom-header-value", + "custom-header": "custom-header-value" ] var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .https, - connectionTarget: .domain(name: "example.com", port: 443), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/get", - headers: [ - "host": "example.com", - "custom-header": "custom-header-value", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .https, + connectionTarget: .domain(name: "example.com", port: 443), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/get", + headers: [ + "host": "example.com", + "custom-header": "custom-header-value", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif + } + + func testBasicAuth() { + XCTAsyncTest { + var request = Request(url: "https://example.com/get") + request.setBasicAuth(username: "foo", password: "bar") + var preparedRequest: PreparedRequest? + XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) + guard let preparedRequest = preparedRequest else { return } + XCTAssertEqual(preparedRequest.head.headers.first(name: "Authorization")!, "Basic Zm9vOmJhcg==") + } } func testUnixScheme() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "unix://%2Fexample%2Ffolder.sock/some_path") request.headers = ["custom-header": "custom-value"] @@ -71,30 +88,37 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .unix, - connectionTarget: .unixSocket(path: "/some_path"), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/", - headers: ["custom-header": "custom-value"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .unix, + connectionTarget: .unixSocket(path: "/some_path"), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/", + headers: ["custom-header": "custom-value"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testHTTPUnixScheme() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http+unix://%2Fexample%2Ffolder.sock/some_path") request.headers = ["custom-header": "custom-value"] @@ -102,30 +126,37 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .httpUnix, - connectionTarget: .unixSocket(path: "/example/folder.sock"), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/some_path", - headers: ["custom-header": "custom-value"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .httpUnix, + connectionTarget: .unixSocket(path: "/example/folder.sock"), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/some_path", + headers: ["custom-header": "custom-value"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testHTTPSUnixScheme() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "https+unix://%2Fexample%2Ffolder.sock/some_path") request.headers = ["custom-header": "custom-value"] @@ -133,60 +164,74 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .httpsUnix, - connectionTarget: .unixSocket(path: "/example/folder.sock"), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/some_path", - headers: ["custom-header": "custom-value"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .httpsUnix, + connectionTarget: .unixSocket(path: "/example/folder.sock"), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/some_path", + headers: ["custom-header": "custom-value"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testGetWithoutBody() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let request = Request(url: "https://example.com/get") var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .https, - connectionTarget: .domain(name: "example.com", port: 443), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/get", - headers: ["host": "example.com"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .https, + connectionTarget: .domain(name: "example.com", port: 443), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/get", + headers: ["host": "example.com"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testPostWithoutBody() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST @@ -194,34 +239,41 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "0", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "0", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testPostWithEmptyByteBuffer() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST @@ -230,34 +282,41 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "0", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "0", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } - #endif } func testPostWithByteBuffer() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST @@ -266,106 +325,127 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif } func testPostWithSequenceOfUnknownLength() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST - let sequence = AnySequence(ByteBuffer(string: "post body").readableBytesView) + let sequence = AnySendableSequence(ByteBuffer(string: "post body").readableBytesView) request.body = .bytes(sequence, length: .unknown) var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "transfer-encoding": "chunked", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .stream - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "transfer-encoding": "chunked", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .stream + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif } func testPostWithSequenceWithFixedLength() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST - let sequence = AnySequence(ByteBuffer(string: "post body").readableBytesView) - request.body = .bytes(sequence, length: .known(9)) + let sequence = AnySendableSequence(ByteBuffer(string: "post body").readableBytesView) + request.body = .bytes(sequence, length: .known(Int64(9))) var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif } func testPostWithRandomAccessCollection() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST @@ -375,40 +455,47 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif } func testPostWithAsyncSequenceOfUnknownLength() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST let asyncSequence = ByteBuffer(string: "post body") .readableBytesView - .chunked(maxChunkSize: 2) - .asAsyncSequence() + .uncheckedSendableChunks(ofCount: 2) + .async .map { ByteBuffer($0) } request.body = .stream(asyncSequence, length: .unknown) @@ -416,84 +503,258 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "transfer-encoding": "chunked", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .stream - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "transfer-encoding": "chunked", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .stream + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif } func testPostWithAsyncSequenceWithKnownLength() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { var request = Request(url: "http://example.com/post") request.method = .POST let asyncSequence = ByteBuffer(string: "post body") .readableBytesView - .chunked(maxChunkSize: 2) - .asAsyncSequence() + .uncheckedSendableChunks(ofCount: 2) + .async .map { ByteBuffer($0) } - request.body = .stream(asyncSequence, length: .known(9)) + request.body = .stream(asyncSequence, length: .known(Int64(9))) var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } - #endif + } + + func testChunkingRandomAccessCollection() async throws { + let body = try await HTTPClientRequest.Body.bytes( + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + ).collect() + + let expectedChunks = [ + ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize), + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testChunkingCollection() async throws { + let body = try await HTTPClientRequest.Body.bytes( + (String(repeating: "0", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "1", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "2", count: bagOfBytesToByteBufferConversionChunkSize)).utf8, + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)) + ).collect() + + let expectedChunks = [ + ByteBuffer(repeating: UInt8(ascii: "0"), count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: UInt8(ascii: "1"), count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: UInt8(ascii: "2"), count: bagOfBytesToByteBufferConversionChunkSize), + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testChunkingSequenceThatDoesNotImplementWithContiguousStorageIfAvailable() async throws { + let bagOfBytesToByteBufferConversionChunkSize = 8 + let body = try await HTTPClientRequest.Body._bytes( + AnySendableSequence( + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + ), + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)), + bagOfBytesToByteBufferConversionChunkSize: bagOfBytesToByteBufferConversionChunkSize, + byteBufferMaxSize: byteBufferMaxSize + ).collect() + + let expectedChunks = [ + ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize), + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testChunkingSequenceFastPath() async throws { + func makeBytes() -> some Sequence & Sendable { + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + } + let body = try await HTTPClientRequest.Body.bytes( + makeBytes(), + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)) + ).collect() + + var firstChunk = ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + firstChunk.writeImmutableBuffer(ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize)) + firstChunk.writeImmutableBuffer(ByteBuffer(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize)) + let expectedChunks = [ + firstChunk + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testChunkingSequenceFastPathExceedingByteBufferMaxSize() async throws { + let bagOfBytesToByteBufferConversionChunkSize = 8 + let byteBufferMaxSize = 16 + func makeBytes() -> some Sequence & Sendable { + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + } + let body = try await HTTPClientRequest.Body._bytes( + makeBytes(), + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)), + bagOfBytesToByteBufferConversionChunkSize: bagOfBytesToByteBufferConversionChunkSize, + byteBufferMaxSize: byteBufferMaxSize + ).collect() + + var firstChunk = ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + firstChunk.writeImmutableBuffer(ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize)) + let secondChunk = ByteBuffer(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + let expectedChunks = [ + firstChunk, + secondChunk, + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testBodyStringChunking() throws { + let body = try HTTPClient.Body.string( + String(repeating: "0", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "1", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "2", count: bagOfBytesToByteBufferConversionChunkSize) + ).collect().wait() + + let expectedChunks = [ + ByteBuffer(), // We're currently emitting an empty chunk first. + ByteBuffer(repeating: UInt8(ascii: "0"), count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: UInt8(ascii: "1"), count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: UInt8(ascii: "2"), count: bagOfBytesToByteBufferConversionChunkSize), + ] + + XCTAssertEqual(body, expectedChunks) + } + + func testBodyChunkingRandomAccessCollection() throws { + let body = try HTTPClient.Body.bytes( + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + ).collect().wait() + + let expectedChunks = [ + ByteBuffer(), // We're currently emitting an empty chunk first. + ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize), + ByteBuffer(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize), + ] + + XCTAssertEqual(body, expectedChunks) + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension AsyncSequence { + func collect() async throws -> [Element] { + try await self.reduce(into: []) { $0 += CollectionOfOne($1) } + } +} + +extension HTTPClient.Body { + func collect() -> EventLoopFuture<[ByteBuffer]> { + let eelg = EmbeddedEventLoopGroup(loops: 1) + let el = eelg.next() + let body = NIOLockedValueBox<[ByteBuffer]>([]) + let writer = StreamWriter { + switch $0 { + case .byteBuffer(let byteBuffer): + body.withLockedValue { $0.append(byteBuffer) } + case .fileRegion: + fatalError("file region not supported") + } + return el.makeSucceededVoidFuture() + } + return self.stream(writer).map { _ in body.withLockedValue { $0 } } } } -#if compiler(>=5.5.2) && canImport(_Concurrency) private struct LengthMismatch: Error { - var announcedLength: Int - var actualLength: Int + var announcedLength: Int64 + var actualLength: Int64 } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension Optional where Wrapped == HTTPClientRequest.Body { +extension Optional where Wrapped == HTTPClientRequest.Prepared.Body { /// Accumulates all data from `self` into a single `ByteBuffer` and checks that the user specified length matches /// the length of the accumulated data. fileprivate func read() async throws -> ByteBuffer { - switch self?.mode { + switch self { case .none: return ByteBuffer() case .byteBuffer(let buffer): @@ -501,77 +762,58 @@ extension Optional where Wrapped == HTTPClientRequest.Body { case .sequence(let announcedLength, _, let generate): let buffer = generate(ByteBufferAllocator()) if case .known(let announcedLength) = announcedLength, - announcedLength != buffer.readableBytes { - throw LengthMismatch(announcedLength: announcedLength, actualLength: buffer.readableBytes) + announcedLength != Int64(buffer.readableBytes) + { + throw LengthMismatch(announcedLength: announcedLength, actualLength: Int64(buffer.readableBytes)) } return buffer - case .asyncSequence(length: let announcedLength, let generate): + case .asyncSequence(length: let announcedLength, let makeAsyncIterator): var accumulatedBuffer = ByteBuffer() + let generate = makeAsyncIterator() while var buffer = try await generate(ByteBufferAllocator()) { accumulatedBuffer.writeBuffer(&buffer) } if case .known(let announcedLength) = announcedLength, - announcedLength != accumulatedBuffer.readableBytes { - throw LengthMismatch(announcedLength: announcedLength, actualLength: accumulatedBuffer.readableBytes) + announcedLength != Int64(accumulatedBuffer.readableBytes) + { + throw LengthMismatch( + announcedLength: announcedLength, + actualLength: Int64(accumulatedBuffer.readableBytes) + ) } return accumulatedBuffer } } } -struct ChunkedSequence: Sequence { - struct Iterator: IteratorProtocol { - fileprivate var remainingElements: Wrapped.SubSequence - fileprivate let maxChunkSize: Int - mutating func next() -> Wrapped.SubSequence? { - guard !self.remainingElements.isEmpty else { - return nil - } - let chunk = self.remainingElements.prefix(self.maxChunkSize) - self.remainingElements = self.remainingElements.dropFirst(self.maxChunkSize) - return chunk - } - } +// swift-algorithms hasn't adopted Sendable yet. By inspection ChunksOfCountCollection should be +// Sendable assuming the underlying collection is. This wrapper allows us to avoid a blanket +// preconcurrency import of the Algorithms module. +struct UncheckedSendableChunksOfCountCollection: Collection, @unchecked Sendable +where Base: Sendable { + typealias Element = Base.SubSequence + typealias Index = ChunksOfCountCollection.Index - fileprivate let wrapped: Wrapped - fileprivate let maxChunkSize: Int + private let underlying: ChunksOfCountCollection - func makeIterator() -> Iterator { - .init(remainingElements: self.wrapped[...], maxChunkSize: self.maxChunkSize) + init(_ underlying: ChunksOfCountCollection) { + self.underlying = underlying } -} -extension Collection { - /// Lazily splits `self` into `SubSequence`s with `maxChunkSize` elements. - /// - Parameter maxChunkSize: size of each chunk except the last one which can be smaller if not enough elements are remaining. - func chunked(maxChunkSize: Int) -> ChunkedSequence { - .init(wrapped: self, maxChunkSize: maxChunkSize) - } -} + var startIndex: Index { self.underlying.startIndex } + var endIndex: Index { self.underlying.endIndex } -@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -struct AsyncSequenceFromSyncSequence: AsyncSequence { - typealias Element = Wrapped.Element - struct AsyncIterator: AsyncIteratorProtocol { - fileprivate var iterator: Wrapped.Iterator - mutating func next() async throws -> Wrapped.Element? { - self.iterator.next() - } + subscript(position: Index) -> Base.SubSequence { + self.underlying[position] } - fileprivate let wrapped: Wrapped - - func makeAsyncIterator() -> AsyncIterator { - .init(iterator: self.wrapped.makeIterator()) + func index(after i: Index) -> Index { + self.underlying.index(after: i) } } -@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension Sequence { - /// Turns `self` into an `AsyncSequence` by wending each element of `self` asynchronously. - func asAsyncSequence() -> AsyncSequenceFromSyncSequence { - .init(wrapped: self) +extension Collection where Self: Sendable { + func uncheckedSendableChunks(ofCount count: Int) -> UncheckedSendableChunksOfCountCollection { + UncheckedSendableChunksOfCountCollection(self.chunks(ofCount: count)) } } - -#endif diff --git a/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift new file mode 100644 index 000000000..7dcc4efe6 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOHTTP1 +import XCTest + +@testable import AsyncHTTPClient + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +final class HTTPClientResponseTests: XCTestCase { + func testSimpleResponse() { + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .GET, + headers: ["content-length": "1025"], + status: .ok + ) + XCTAssertEqual(response, 1025) + } + + func testSimpleResponseNotModified() { + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .GET, + headers: ["content-length": "1025"], + status: .notModified + ) + XCTAssertEqual(response, 0) + } + + func testSimpleResponseHeadRequestMethod() { + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .HEAD, + headers: ["content-length": "1025"], + status: .ok + ) + XCTAssertEqual(response, 0) + } + + func testResponseNoContentLengthHeader() { + let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: [:], status: .ok) + XCTAssertEqual(response, nil) + } + + func testResponseInvalidInteger() { + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .GET, + headers: ["content-length": "none"], + status: .ok + ) + XCTAssertEqual(response, nil) + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 230c91a2b..689b4358e 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -12,12 +12,14 @@ // //===----------------------------------------------------------------------===// -import AsyncHTTPClient +import Atomics import Foundation +import InMemoryLogging import Logging import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOFoundationCompat import NIOHPACK import NIOHTTP1 import NIOHTTP2 @@ -27,8 +29,19 @@ import NIOSSL import NIOTLS import NIOTransportServices import XCTest -#if canImport(Darwin) + +@testable import AsyncHTTPClient + +#if canImport(xlocale) +import xlocale +#elseif canImport(locale_h) +import locale_h +#elseif canImport(Darwin) import Darwin +#elseif canImport(Musl) +import Musl +#elseif canImport(Android) +import Android #elseif canImport(Glibc) import Glibc #endif @@ -45,7 +58,8 @@ func isTestingNIOTS() -> Bool { func getDefaultEventLoopGroup(numberOfThreads: Int) -> EventLoopGroup { #if canImport(Network) if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), - isTestingNIOTS() { + isTestingNIOTS() + { return NIOTSEventLoopGroup(loopCount: numberOfThreads, defaultQoS: .default) } #endif @@ -82,15 +96,13 @@ func withCLocaleSetToGerman(_ body: () throws -> Void) throws { try body() } -class TestHTTPDelegate: HTTPClientResponseDelegate { +final class TestHTTPDelegate: HTTPClientResponseDelegate { typealias Response = Void init(backpressureEventLoop: EventLoop? = nil) { - self.backpressureEventLoop = backpressureEventLoop + self.state = NIOLockedValueBox(MutableState(backpressureEventLoop: backpressureEventLoop)) } - var backpressureEventLoop: EventLoop? - enum State { case idle case head(HTTPResponseHead) @@ -99,77 +111,96 @@ class TestHTTPDelegate: HTTPClientResponseDelegate { case error(Error) } - var state = State.idle + struct MutableState: Sendable { + var state: State = .idle + var backpressureEventLoop: EventLoop? + } + + let state: NIOLockedValueBox func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.state = .head(head) - return (self.backpressureEventLoop ?? task.eventLoop).makeSucceededFuture(()) + let eventLoop = self.state.withLockedValue { + $0.state = .head(head) + return ($0.backpressureEventLoop ?? task.eventLoop) + } + + return eventLoop.makeSucceededVoidFuture() } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - switch self.state { - case .head(let head): - self.state = .body(head, buffer) - case .body(let head, var body): - var buffer = buffer - body.writeBuffer(&buffer) - self.state = .body(head, body) - default: - preconditionFailure("expecting head or body") + let eventLoop = self.state.withLockedValue { + switch $0.state { + case .head(let head): + $0.state = .body(head, buffer) + case .body(let head, var body): + var buffer = buffer + body.writeBuffer(&buffer) + $0.state = .body(head, body) + default: + preconditionFailure("expecting head or body") + } + return ($0.backpressureEventLoop ?? task.eventLoop) } - return (self.backpressureEventLoop ?? task.eventLoop).makeSucceededFuture(()) + + return eventLoop.makeSucceededVoidFuture() } func didFinishRequest(task: HTTPClient.Task) throws {} } -class CountingDelegate: HTTPClientResponseDelegate { +final class CountingDelegate: HTTPClientResponseDelegate { typealias Response = Int - var count = 0 + private let _count = NIOLockedValueBox(0) func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { let str = buffer.getString(at: 0, length: buffer.readableBytes) if str?.starts(with: "id:") ?? false { - self.count += 1 + self._count.withLockedValue { $0 += 1 } } return task.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Int { - return self.count + self._count.withLockedValue { $0 } } } -class DelayOnHeadDelegate: HTTPClientResponseDelegate { +final class DelayOnHeadDelegate: HTTPClientResponseDelegate { typealias Response = ByteBuffer let eventLoop: EventLoop - let didReceiveHead: (HTTPResponseHead, EventLoopPromise) -> Void - - private var data: ByteBuffer + let didReceiveHead: @Sendable (HTTPResponseHead, EventLoopPromise) -> Void - private var mayReceiveData = false + struct State: Sendable { + var data: ByteBuffer + var mayReceiveData = false + var expectError = false + } - private var expectError = false + private let state: NIOLockedValueBox - init(eventLoop: EventLoop, didReceiveHead: @escaping (HTTPResponseHead, EventLoopPromise) -> Void) { + init(eventLoop: EventLoop, didReceiveHead: @escaping @Sendable (HTTPResponseHead, EventLoopPromise) -> Void) { self.eventLoop = eventLoop self.didReceiveHead = didReceiveHead - self.data = ByteBuffer() + self.state = NIOLockedValueBox(State(data: ByteBuffer())) } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - XCTAssertFalse(self.mayReceiveData) - XCTAssertFalse(self.expectError) + self.state.withLockedValue { + XCTAssertFalse($0.mayReceiveData) + XCTAssertFalse($0.expectError) + } let promise = self.eventLoop.makePromise(of: Void.self) - promise.futureResult.whenComplete { - switch $0 { - case .success: - self.mayReceiveData = true - case .failure: - self.expectError = true + promise.futureResult.whenComplete { result in + self.state.withLockedValue { state in + switch result { + case .success: + state.mayReceiveData = true + case .failure: + state.expectError = true + } } } @@ -178,20 +209,26 @@ class DelayOnHeadDelegate: HTTPClientResponseDelegate { } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - XCTAssertTrue(self.mayReceiveData) - XCTAssertFalse(self.expectError) - self.data.writeImmutableBuffer(buffer) + self.state.withLockedValue { + XCTAssertTrue($0.mayReceiveData) + XCTAssertFalse($0.expectError) + $0.data.writeImmutableBuffer(buffer) + } return self.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Response { - XCTAssertTrue(self.mayReceiveData) - XCTAssertFalse(self.expectError) - return self.data + self.state.withLockedValue { + XCTAssertTrue($0.mayReceiveData) + XCTAssertFalse($0.expectError) + return $0.data + } } func didReceiveError(task: HTTPClient.Task, _ error: Error) { - XCTAssertTrue(self.expectError) + self.state.withLockedValue { + XCTAssertTrue($0.expectError) + } } } @@ -212,8 +249,8 @@ enum TemporaryFileHelpers { } else { return "/tmp" } - #endif // os - #endif // targetEnvironment + #endif // os + #endif // targetEnvironment } private static func openTemporaryFile() -> (CInt, String) { @@ -233,8 +270,10 @@ enum TemporaryFileHelpers { /// /// If the temporary directory is too long to store a UNIX domain socket path, it will `chdir` into the temporary /// directory and return a short-enough path. The iOS simulator is known to have too long paths. - internal static func withTemporaryUnixDomainSocketPathName(directory: String = temporaryDirectory, - _ body: (String) throws -> T) throws -> T { + internal static func withTemporaryUnixDomainSocketPathName( + directory: String = temporaryDirectory, + _ body: (String) throws -> T + ) throws -> T { // this is racy but we're trying to create the shortest possible path so we can't add a directory... let (fd, path) = self.openTemporaryFile() close(fd) @@ -249,17 +288,21 @@ enum TemporaryFileHelpers { shortEnoughPath = path restoreSavedCWD = false } catch SocketAddressError.unixDomainSocketPathTooLong { - FileManager.default.changeCurrentDirectoryPath(URL(fileURLWithPath: path).deletingLastPathComponent().absoluteString) + _ = FileManager.default.changeCurrentDirectoryPath( + URL(fileURLWithPath: path).deletingLastPathComponent().absoluteString + ) shortEnoughPath = URL(fileURLWithPath: path).lastPathComponent restoreSavedCWD = true - print("WARNING: Path '\(path)' could not be used as UNIX domain socket path, using chdir & '\(shortEnoughPath)'") + print( + "WARNING: Path '\(path)' could not be used as UNIX domain socket path, using chdir & '\(shortEnoughPath)'" + ) } defer { if FileManager.default.fileExists(atPath: path) { try? FileManager.default.removeItem(atPath: path) } if restoreSavedCWD { - FileManager.default.changeCurrentDirectoryPath(saveCurrentDirectory) + _ = FileManager.default.changeCurrentDirectoryPath(saveCurrentDirectory) } } return try body(shortEnoughPath) @@ -282,23 +325,52 @@ enum TemporaryFileHelpers { return try body(path) } + internal static func makeTemporaryFilePath( + directory: String = temporaryDirectory + ) -> String { + let (fd, path) = self.openTemporaryFile() + close(fd) + try! FileManager.default.removeItem(atPath: path) + return path + } + + internal static func removeTemporaryFile( + at path: String + ) { + if FileManager.default.fileExists(atPath: path) { + try? FileManager.default.removeItem(atPath: path) + } + } + internal static func fileSize(path: String) throws -> Int? { - return try FileManager.default.attributesOfItem(atPath: path)[.size] as? Int + try FileManager.default.attributesOfItem(atPath: path)[.size] as? Int } internal static func fileExists(path: String) -> Bool { - return FileManager.default.fileExists(atPath: path) + FileManager.default.fileExists(atPath: path) } } enum TestTLS { static let certificate = try! NIOSSLCertificate(bytes: Array(cert.utf8), format: .pem) static let privateKey = try! NIOSSLPrivateKey(bytes: Array(key.utf8), format: .pem) + static let serverConfiguration: TLSConfiguration = .makeServerConfiguration( + certificateChain: [.certificate(TestTLS.certificate)], + privateKey: .privateKey(TestTLS.privateKey) + ) } -internal final class HTTPBin where +#if compiler(>=6.2) +typealias AHCTestSendableMetatype = SendableMetatype +#else +typealias AHCTestSendableMetatype = Any +#endif + +internal final class HTTPBin: Sendable +where RequestHandler.InboundIn == HTTPServerRequestPart, - RequestHandler.OutboundOut == HTTPServerResponsePart { + RequestHandler.OutboundOut == HTTPServerResponsePart +{ enum BindTarget { case unixDomainSocket(String) case localhostIPv4RandomPort @@ -309,19 +381,51 @@ internal final class HTTPBin where // refuses all connections case refuse // supports http1.1 connections only, which can be either plain text or encrypted - case http1_1(ssl: Bool = false, compress: Bool = false) + case http1_1( + tlsConfiguration: TLSConfiguration? = nil, + compress: Bool = false + ) // supports http1.1 and http2 connections which must be always encrypted - case http2(compress: Bool) + case http2( + tlsConfiguration: TLSConfiguration = TestTLS.serverConfiguration, + compress: Bool = false, + settings: HTTP2Settings? = nil + ) + + static func http1_1(ssl: Bool, compress: Bool = false) -> Self { + .http1_1(tlsConfiguration: ssl ? TestTLS.serverConfiguration : nil, compress: compress) + } // supports request decompression and http response compression var compress: Bool { switch self { case .refuse: return false - case .http1_1(ssl: _, compress: let compress), .http2(compress: let compress): + case .http1_1(_, let compress), .http2(_, let compress, _): return compress } } + + var httpSettings: HTTP2Settings { + switch self { + case .http1_1, .http2(_, _, nil), .refuse: + return HTTP2Connection.defaultSettings + case .http2(_, _, .some(let customSettings)): + return customSettings + } + } + + var tlsConfiguration: TLSConfiguration? { + switch self { + case .refuse: + return nil + case .http1_1(let tlsConfiguration, _): + return tlsConfiguration + case .http2(var tlsConfiguration, _, _): + tlsConfiguration.applicationProtocols = NIOHTTP2SupportedALPNProtocols + return tlsConfiguration + } + } } enum Proxy { @@ -333,32 +437,61 @@ internal final class HTTPBin where private let activeConnCounterHandler: ConnectionsCountHandler var activeConnections: Int { - return self.activeConnCounterHandler.currentlyActiveConnections + self.activeConnCounterHandler.currentlyActiveConnections } var createdConnections: Int { - return self.activeConnCounterHandler.createdConnections + self.activeConnCounterHandler.createdConnections } var port: Int { - return Int(self.serverChannel.localAddress!.port!) + self.serverChannel.withLockedValue { + Int($0!.localAddress!.port!) + } } var socketAddress: SocketAddress { - return self.serverChannel.localAddress! + self.serverChannel.withLockedValue { + $0!.localAddress! + } + } + + var baseURL: String { + let scheme: String = { + switch mode { + case .http1_1, .refuse: + return "http" + case .http2: + return "https" + } + }() + let host: String = { + switch self.socketAddress { + case .v4: + return self.socketAddress.ipAddress! + case .v6: + return "[\(self.socketAddress.ipAddress!)]" + case .unixDomainSocket: + return self.socketAddress.pathname! + } + }() + + return "\(scheme)://\(host):\(self.port)/" } private let mode: Mode private let sslContext: NIOSSLContext? - private var serverChannel: Channel! - private let isShutdown: NIOAtomic = .makeAtomic(value: false) - private let handlerFactory: (Int) -> (RequestHandler) + private let serverChannel = NIOLockedValueBox(nil) + private let isShutdown = ManagedAtomic(false) + private let handlerFactory: @Sendable (Int) -> (RequestHandler) init( _ mode: Mode = .http1_1(ssl: false, compress: false), proxy: Proxy = .none, bindTarget: BindTarget = .localhostIPv4RandomPort, - handlerFactory: @escaping (Int) -> (RequestHandler) + reusePort: Bool = false, + trafficShapingTargetBytesPerSecond: Int? = nil, + handlerFactory: @escaping @Sendable (Int) -> (RequestHandler) ) { self.mode = mode self.sslContext = HTTPBin.sslContext(for: mode) @@ -376,15 +509,26 @@ internal final class HTTPBin where self.activeConnCounterHandler = ConnectionsCountHandler() - let connectionIDAtomic = NIOAtomic.makeAtomic(value: 0) + let connectionIDAtomic = ManagedAtomic(0) - self.serverChannel = try! ServerBootstrap(group: self.group) + let serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .serverChannelInitializer { channel in - channel.pipeline.addHandler(self.activeConnCounterHandler) + .serverChannelOption( + ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), + value: reusePort ? 1 : 0 + ) + .serverChannelInitializer { [activeConnCounterHandler] channel in + channel.pipeline.addHandler(activeConnCounterHandler) }.childChannelInitializer { channel in + if let trafficShapingTargetBytesPerSecond = trafficShapingTargetBytesPerSecond { + try! channel.pipeline.syncOperations.addHandler( + BasicInboundTrafficShapingHandler( + targetBytesPerSecond: trafficShapingTargetBytesPerSecond + ) + ) + } do { - let connectionID = connectionIDAtomic.add(1) + let connectionID = connectionIDAtomic.loadThenWrappingIncrement(ordering: .relaxed) if case .refuse = mode { throw HTTPBinError.refusedConnection @@ -418,6 +562,7 @@ internal final class HTTPBin where return channel.eventLoop.makeFailedFuture(error) } }.bind(to: socketAddress).wait() + self.serverChannel.withLockedValue { $0 = serverChannel } } private func syncAddHTTPProxyHandlers( @@ -430,18 +575,18 @@ internal final class HTTPBin where let responseEncoder = HTTPResponseEncoder() let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) - let proxySimulator = HTTPProxySimulator(promise: promise, expectedAuhorization: expectedAuthorization) + let proxySimulator = HTTPProxySimulator(promise: promise, expectedAuthorization: expectedAuthorization) try sync.addHandler(responseEncoder) try sync.addHandler(requestDecoder) try sync.addHandler(proxySimulator) - promise.futureResult.flatMap { _ in - channel.pipeline.removeHandler(proxySimulator) + promise.futureResult.assumeIsolated().flatMap { _ in + channel.pipeline.syncOperations.removeHandler(proxySimulator) }.flatMap { _ in - channel.pipeline.removeHandler(responseEncoder) + channel.pipeline.syncOperations.removeHandler(responseEncoder) }.flatMap { _ in - channel.pipeline.removeHandler(requestDecoder) + channel.pipeline.syncOperations.removeHandler(requestDecoder) }.whenComplete { result in switch result { case .failure: @@ -484,30 +629,8 @@ internal final class HTTPBin where } } - private static func tlsConfiguration(for mode: Mode) -> TLSConfiguration? { - var configuration: TLSConfiguration? - - switch mode { - case .refuse, .http1_1(ssl: false, compress: _): - break - case .http2: - configuration = .makeServerConfiguration( - certificateChain: [.certificate(TestTLS.certificate)], - privateKey: .privateKey(TestTLS.privateKey) - ) - configuration!.applicationProtocols = NIOHTTP2SupportedALPNProtocols - case .http1_1(ssl: true, compress: _): - configuration = .makeServerConfiguration( - certificateChain: [.certificate(TestTLS.certificate)], - privateKey: .privateKey(TestTLS.privateKey) - ) - } - - return configuration - } - private static func sslContext(for mode: Mode) -> NIOSSLContext? { - if let tlsConfiguration = self.tlsConfiguration(for: mode) { + if let tlsConfiguration = mode.tlsConfiguration { return try! NIOSSLContext(configuration: tlsConfiguration) } return nil @@ -524,20 +647,20 @@ internal final class HTTPBin where // Successful upgrade to HTTP/2. Let the user configure the pipeline. let http2Handler = NIOHTTP2Handler( mode: .server, - initialSettings: [ - // TODO: make max concurrent streams configurable - HTTP2Setting(parameter: .maxConcurrentStreams, value: 10), - HTTP2Setting(parameter: .maxHeaderListSize, value: HPACKDecoder.defaultMaxHeaderListSize), - ] + initialSettings: self.mode.httpSettings ) let multiplexer = HTTP2StreamMultiplexer( mode: .server, channel: channel, + targetWindowSize: 16 * 1024 * 1024, // 16 MiB inboundStreamInitializer: { channel in do { let sync = channel.pipeline.syncOperations try sync.addHandler(HTTP2FramePayloadToHTTP1ServerCodec()) + if self.mode.compress { + try sync.addHandler(HTTPResponseCompressor()) + } try sync.addHandler(self.handlerFactory(connectionID)) return channel.eventLoop.makeSucceededVoidFuture() @@ -567,17 +690,17 @@ internal final class HTTPBin where } } + try channel.pipeline.syncOperations.addHandler(sslHandler) try channel.pipeline.syncOperations.addHandler(alpnHandler) - try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(alpnHandler)) } func shutdown() throws { - self.isShutdown.store(true) + self.isShutdown.store(true, ordering: .relaxed) try self.group.syncShutdownGracefully() } deinit { - assert(self.isShutdown.load(), "HTTPBin not shutdown before deinit") + assert(self.isShutdown.load(ordering: .relaxed), "HTTPBin not shutdown before deinit") } } @@ -585,9 +708,17 @@ extension HTTPBin where RequestHandler == HTTPBinHandler { convenience init( _ mode: Mode = .http1_1(ssl: false, compress: false), proxy: Proxy = .none, - bindTarget: BindTarget = .localhostIPv4RandomPort + bindTarget: BindTarget = .localhostIPv4RandomPort, + reusePort: Bool = false, + trafficShapingTargetBytesPerSecond: Int? = nil ) { - self.init(mode, proxy: proxy, bindTarget: bindTarget) { HTTPBinHandler(connectionID: $0) } + self.init( + mode, + proxy: proxy, + bindTarget: bindTarget, + reusePort: reusePort, + trafficShapingTargetBytesPerSecond: trafficShapingTargetBytesPerSecond + ) { HTTPBinHandler(connectionID: $0) } } } @@ -603,14 +734,18 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { // the promise to succeed, once the proxy connection is setup let promise: EventLoopPromise - let expectedAuhorization: String? + let expectedAuthorization: String? var head: HTTPResponseHead - init(promise: EventLoopPromise, expectedAuhorization: String?) { + init(promise: EventLoopPromise, expectedAuthorization: String?) { self.promise = promise - self.expectedAuhorization = expectedAuhorization - self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0")])) + self.expectedAuthorization = expectedAuthorization + self.head = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .ok, + headers: .init([("Content-Length", "0")]) + ) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -622,9 +757,10 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { return } - if let expectedAuhorization = self.expectedAuhorization { + if let expectedAuthorization = self.expectedAuthorization { guard let authorization = head.headers["proxy-authorization"].first, - expectedAuhorization == authorization else { + expectedAuthorization == authorization + else { self.head.status = .proxyAuthenticationRequired return } @@ -648,12 +784,31 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { internal struct HTTPResponseBuilder { var head: HTTPResponseHead var body: ByteBuffer? + var requestBodyByteCount: Int + let responseBodyIsRequestBodyByteCount: Bool - init(_ version: HTTPVersion = HTTPVersion(major: 1, minor: 1), status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders()) { + init( + _ version: HTTPVersion = HTTPVersion(major: 1, minor: 1), + status: HTTPResponseStatus, + headers: HTTPHeaders = HTTPHeaders(), + responseBodyIsRequestBodyByteCount: Bool = false + ) { self.head = HTTPResponseHead(version: version, status: status, headers: headers) + self.requestBodyByteCount = 0 + self.responseBodyIsRequestBodyByteCount = responseBodyIsRequestBodyByteCount } mutating func add(_ part: ByteBuffer) { + self.requestBodyByteCount += part.readableBytes + guard !self.responseBodyIsRequestBodyByteCount else { + if self.body == nil { + self.body = ByteBuffer() + self.body!.reserveCapacity(100) + } + self.body!.clear() + self.body!.writeString("\(self.requestBodyByteCount)") + return + } if var body = body { var part = part body.writeBuffer(&part) @@ -701,8 +856,10 @@ internal final class HTTPBinHandler: ChannelInboundHandler { for header in head.headers { let needle = "x-send-back-header-" if header.name.lowercased().starts(with: needle) { - self.responseHeaders.add(name: String(header.name.dropFirst(needle.count)), - value: header.value) + self.responseHeaders.add( + name: String(header.name.dropFirst(needle.count)), + value: header.value + ) } } } @@ -715,7 +872,12 @@ internal final class HTTPBinHandler: ChannelInboundHandler { headers = HTTPHeaders() } - context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + context.write( + wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) for i in 0..<10 { let msg = "id: \(i)" var buf = context.channel.allocator.buffer(capacity: msg.count) @@ -730,7 +892,12 @@ internal final class HTTPBinHandler: ChannelInboundHandler { // This tests receiving chunks very fast: please do not insert delays here! let headers = HTTPHeaders([("Transfer-Encoding", "chunked")]) - context.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + context.write( + self.wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) for i in 0..<10 { let msg = "id: \(i)" var buf = context.channel.allocator.buffer(capacity: msg.count) @@ -741,6 +908,27 @@ internal final class HTTPBinHandler: ChannelInboundHandler { context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } + func writeManyChunks(context: ChannelHandlerContext) { + // This tests receiving a lot of tiny chunks: they must all be sent in a single flush or the test doesn't work. + let headers = HTTPHeaders([("Transfer-Encoding", "chunked")]) + + context.write( + self.wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) + let message = ByteBuffer(integer: UInt8(ascii: "a")) + + // This number (10k) is load-bearing and a bit magic: it has been experimentally verified as being sufficient to blow the stack + // in the old implementation on all testing platforms. Please don't change it without good reason. + for _ in 0..<10_000 { + context.write(wrapOutboundOut(.body(.byteBuffer(message))), promise: nil) + } + + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.isServingRequest = true switch self.unwrapInboundIn(data) { @@ -789,6 +977,13 @@ internal final class HTTPBinHandler: ChannelInboundHandler { } self.resps.append(HTTPResponseBuilder(status: .ok)) return + case "/post-respond-with-byte-count": + if req.method != .POST { + self.resps.append(HTTPResponseBuilder(status: .methodNotAllowed)) + return + } + self.resps.append(HTTPResponseBuilder(status: .ok, responseBodyIsRequestBodyByteCount: true)) + return case "/redirect/302": var headers = self.responseHeaders headers.add(name: "location", value: "/ok") @@ -849,9 +1044,12 @@ internal final class HTTPBinHandler: ChannelInboundHandler { context.close(promise: nil) return case "/custom": - context.writeAndFlush(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), promise: nil) + context.writeAndFlush( + wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), + promise: nil + ) return - case "/events/10/1": // TODO: parse path + case "/events/10/1": // TODO: parse path self.writeEvents(context: context) return case "/events/10/content-length": @@ -859,6 +1057,9 @@ internal final class HTTPBinHandler: ChannelInboundHandler { case "/chunked": self.writeChunked(context: context) return + case "/mega-chunked": + self.writeManyChunks(context: context) + return case "/close-on-response": var headers = self.responseHeaders headers.replaceOrAdd(name: "connection", value: "close") @@ -869,8 +1070,23 @@ internal final class HTTPBinHandler: ChannelInboundHandler { // We're forcing this closed now. self.shouldClose = true self.resps.append(builder) + case "/content-length-without-body": + var headers = self.responseHeaders + headers.replaceOrAdd(name: "content-length", value: "1234") + context.writeAndFlush( + wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) + return default: - context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .notFound))), promise: nil) + context.write( + wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .notFound)) + ), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) return } @@ -889,32 +1105,41 @@ internal final class HTTPBinHandler: ChannelInboundHandler { response.head.headers.add(contentsOf: self.responseHeaders) context.write(wrapOutboundOut(.head(response.head)), promise: nil) if let body = response.body { - let requestInfo = RequestInfo(data: String(buffer: body), - requestNumber: self.requestId, - connectionNumber: self.connectionID) - let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, - allocator: context.channel.allocator) + let requestInfo = RequestInfo( + data: String(buffer: body), + requestNumber: self.requestId, + connectionNumber: self.connectionID + ) + let responseBody = try! JSONEncoder().encodeAsByteBuffer( + requestInfo, + allocator: context.channel.allocator + ) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) } else { - let requestInfo = RequestInfo(data: "", - requestNumber: self.requestId, - connectionNumber: self.connectionID) - let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, - allocator: context.channel.allocator) + let requestInfo = RequestInfo( + data: "", + requestNumber: self.requestId, + connectionNumber: self.connectionID + ) + let responseBody = try! JSONEncoder().encodeAsByteBuffer( + requestInfo, + allocator: context.channel.allocator + ) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) } - context.eventLoop.scheduleTask(in: self.delay) { + context.eventLoop.assumeIsolated().scheduleTask(in: self.delay) { guard context.channel.isActive else { context.close(promise: nil) return } - context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenComplete { result in + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).assumeIsolated().whenComplete { result in self.isServingRequest = false switch result { case .success: - if self.responseHeaders[canonicalForm: "X-Close-Connection"].contains("true") || - self.shouldClose { + if self.responseHeaders[canonicalForm: "X-Close-Connection"].contains("true") + || self.shouldClose + { context.close(promise: nil) } case .failure(let error): @@ -943,27 +1168,27 @@ internal final class HTTPBinHandler: ChannelInboundHandler { } } -final class ConnectionsCountHandler: ChannelInboundHandler { +final class ConnectionsCountHandler: ChannelInboundHandler, Sendable { typealias InboundIn = Channel - private let activeConns = NIOAtomic.makeAtomic(value: 0) - private let createdConns = NIOAtomic.makeAtomic(value: 0) + private let activeConns = ManagedAtomic(0) + private let createdConns = ManagedAtomic(0) var createdConnections: Int { - self.createdConns.load() + self.createdConns.load(ordering: .relaxed) } var currentlyActiveConnections: Int { - self.activeConns.load() + self.activeConns.load(ordering: .relaxed) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { let channel = self.unwrapInboundIn(data) - _ = self.activeConns.add(1) - _ = self.createdConns.add(1) - channel.closeFuture.whenComplete { _ in - _ = self.activeConns.sub(1) + _ = self.activeConns.loadThenWrappingIncrement(ordering: .relaxed) + _ = self.createdConns.loadThenWrappingIncrement(ordering: .relaxed) + channel.closeFuture.whenComplete { [activeConns] _ in + _ = activeConns.loadThenWrappingDecrement(ordering: .relaxed) } context.fireChannelRead(data) @@ -983,7 +1208,7 @@ internal final class CloseWithoutClosingServerHandler: ChannelInboundHandler { func handlerAdded(context: ChannelHandlerContext) { self.onClosePromise = context.eventLoop.makePromise() - self.onClosePromise!.futureResult.whenSuccess(self.callback!) + self.onClosePromise!.futureResult.assumeIsolated().whenSuccess(self.callback!) self.callback = nil } @@ -1017,9 +1242,35 @@ internal final class CloseWithoutClosingServerHandler: ChannelInboundHandler { } } +final class ExpectClosureServerHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + private let onClosePromise: EventLoopPromise + + init(onClosePromise: EventLoopPromise) { + self.onClosePromise = onClosePromise + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.unwrapInboundIn(data) { + case .head: + let head = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "0"]) + context.write(self.wrapOutboundOut(.head(head)), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .body, .end: + () + } + } + + func channelInactive(context: ChannelHandlerContext) { + self.onClosePromise.succeed(()) + } +} + struct EventLoopFutureTimeoutError: Error {} -extension EventLoopFuture { +extension EventLoopFuture where Value: Sendable { func timeout(after failDelay: TimeAmount) -> EventLoopFuture { let promise = self.eventLoop.makePromise(of: Value.self) @@ -1040,57 +1291,21 @@ extension EventLoopFuture { } } -struct CollectEverythingLogHandler: LogHandler { - var metadata: Logger.Metadata = [:] - var logLevel: Logger.Level = .info - let logStore: LogStore - - class LogStore { - struct Entry { - var level: Logger.Level - var message: String - var metadata: [String: String] - } - - var lock = Lock() - var logs: [Entry] = [] - - var allEntries: [Entry] { - get { - return self.lock.withLock { self.logs } - } - set { - self.lock.withLock { self.logs = newValue } - } - } - - func append(level: Logger.Level, message: Logger.Message, metadata: Logger.Metadata?) { - self.lock.withLock { - self.logs.append(Entry(level: level, - message: message.description, - metadata: metadata?.mapValues { $0.description } ?? [:])) +extension InMemoryLogHandler { + static func makeLogger( + logLevel: Logger.Level = .info, + function: String = #function + ) -> (InMemoryLogHandler, Logger) { + let handler = InMemoryLogHandler() + + var logger = Logger( + label: "\(function)", + factory: { _ in + handler } - } - } - - init(logStore: LogStore) { - self.logStore = logStore - } - - func log(level: Logger.Level, - message: Logger.Message, - metadata: Logger.Metadata?, - file: String, function: String, line: UInt) { - self.logStore.append(level: level, message: message, metadata: self.metadata.merging(metadata ?? [:]) { $1 }) - } - - subscript(metadataKey key: String) -> Logger.Metadata.Value? { - get { - return self.metadata[key] - } - set { - self.metadata[key] = newValue - } + ) + logger.logLevel = logLevel + return (handler, logger) } } @@ -1098,10 +1313,10 @@ struct CollectEverythingLogHandler: LogHandler { /// consume the bytes by calling ``next()`` on the delegate. /// /// The sole purpose of this class is to enable straight-line stream tests. -class ResponseStreamDelegate: HTTPClientResponseDelegate { +final class ResponseStreamDelegate: HTTPClientResponseDelegate { typealias Response = Void - enum State { + enum State: Sendable { /// The delegate is in the idle state. There are no http response parts to be buffered /// and the consumer did not signal a demand. Transitions to all other states are allowed. case idle @@ -1119,10 +1334,11 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { } let eventLoop: EventLoop - private var state: State = .idle + private let state: NIOLoopBoundBox init(eventLoop: EventLoop) { self.eventLoop = eventLoop + self.state = .makeBoxSendingValue(.idle, eventLoop: eventLoop) } func next() -> EventLoopFuture { @@ -1136,25 +1352,25 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { } private func next0() -> EventLoopFuture { - switch self.state { + switch self.state.value { case .idle: let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) - self.state = .waitingForBytes(promise) + self.state.value = .waitingForBytes(promise) return promise.futureResult case .buffering(let byteBuffer, done: false): - self.state = .idle + self.state.value = .idle return self.eventLoop.makeSucceededFuture(byteBuffer) case .buffering(let byteBuffer, done: true): - self.state = .finished + self.state.value = .finished return self.eventLoop.makeSucceededFuture(byteBuffer) case .waitingForBytes: preconditionFailure("Don't call `.next` twice") case .failed(let error): - self.state = .finished + self.state.value = .finished return self.eventLoop.makeFailedFuture(error) case .finished: @@ -1184,16 +1400,16 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { self.eventLoop.preconditionInEventLoop() - switch self.state { + switch self.state.value { case .idle: - self.state = .buffering(buffer, done: false) + self.state.value = .buffering(buffer, done: false) case .waitingForBytes(let promise): - self.state = .idle + self.state.value = .idle promise.succeed(buffer) case .buffering(var byteBuffer, done: false): var buffer = buffer byteBuffer.writeBuffer(&buffer) - self.state = .buffering(byteBuffer, done: false) + self.state.value = .buffering(byteBuffer, done: false) case .buffering(_, done: true), .finished, .failed: preconditionFailure("Invalid state: \(self.state)") } @@ -1204,14 +1420,14 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { func didReceiveError(task: HTTPClient.Task, _ error: Error) { self.eventLoop.preconditionInEventLoop() - switch self.state { + switch self.state.value { case .idle: - self.state = .failed(error) + self.state.value = .failed(error) case .waitingForBytes(let promise): - self.state = .finished + self.state.value = .finished promise.fail(error) case .buffering(_, done: false): - self.state = .failed(error) + self.state.value = .failed(error) case .buffering(_, done: true), .finished, .failed: preconditionFailure("Invalid state: \(self.state)") } @@ -1220,14 +1436,14 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { func didFinishRequest(task: HTTPClient.Task) throws { self.eventLoop.preconditionInEventLoop() - switch self.state { + switch self.state.value { case .idle: - self.state = .finished + self.state.value = .finished case .waitingForBytes(let promise): - self.state = .finished + self.state.value = .finished promise.succeed(nil) case .buffering(let byteBuffer, done: false): - self.state = .buffering(byteBuffer, done: true) + self.state.value = .buffering(byteBuffer, done: true) case .buffering(_, done: true), .finished, .failed: preconditionFailure("Invalid state: \(self.state)") } @@ -1242,63 +1458,211 @@ class HTTPEchoHandler: ChannelInboundHandler { let request = self.unwrapInboundIn(data) switch request { case .head(let requestHead): - context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: requestHead.headers))), promise: nil) + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: requestHead.headers))), + promise: nil + ) case .body(let bytes): context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(bytes))), promise: nil) case .end: - context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenSuccess { + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).assumeIsolated().whenSuccess { context.close(promise: nil) } } } } +final class HTTPEchoHeaders: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let request = self.unwrapInboundIn(data) + switch request { + case .head(let requestHead): + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: requestHead.headers))), + promise: nil + ) + case .body: + break + case .end: + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).assumeIsolated().whenSuccess { + context.close(promise: nil) + } + } + } +} + +final class HTTP200DelayedHandler: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + var pendingBodyParts: Int? + + init(bodyPartsBeforeResponse: Int) { + self.pendingBodyParts = bodyPartsBeforeResponse + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let request = self.unwrapInboundIn(data) + switch request { + case .head: + // Once we have received one response, all further requests are responded to immediately. + if self.pendingBodyParts == nil { + context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + case .body: + if let pendingBodyParts = self.pendingBodyParts { + if pendingBodyParts > 0 { + self.pendingBodyParts = pendingBodyParts - 1 + } else { + self.pendingBodyParts = nil + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), + promise: nil + ) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + } + case .end: + break + } + } +} + private let cert = """ ------BEGIN CERTIFICATE----- -MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1 -czAgFw0xODEwMzExNTU1MjJaGA8yMTE4MTAwNzE1NTUyMlowDTELMAkGA1UEBhMC -dXMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDiC+TGmbSP/nWWN1tj -yNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMisUdb -d3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZHud9 -+JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKzV3S8 -kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAVKcNR -9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO538ljg -dslnAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAFYhA7sw8odOsRO8/DUklBOjPnmn -a078oSumgPXXw6AgcoAJv/Qthjo6CCEtrjYfcA9jaBw9/Tii7mDmqDRS5c9ZPL8+ -NEPdHjFCFBOEvlL6uHOgw0Z9Wz+5yCXnJ8oNUEgc3H2NbbzJF6sMBXSPtFS2NOK8 -OsAI9OodMrDd6+lwljrmFoCCkJHDEfE637IcsbgFKkzhO/oNCRK6OrudG4teDahz -Au4LoEYwT730QKC/VQxxEVZobjn9/sTrq9CZlbPYHxX4fz6e00sX7H9i49vk9zQ5 -5qCm9ljhrQPSa42Q62PPE2BEEGSP2KBm0J+H3vlvCD6+SNc/nMZjrRmgjrI= ------END CERTIFICATE----- -""" + -----BEGIN CERTIFICATE----- + MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1 + czAgFw0xODEwMzExNTU1MjJaGA8yMTE4MTAwNzE1NTUyMlowDTELMAkGA1UEBhMC + dXMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDiC+TGmbSP/nWWN1tj + yNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMisUdb + d3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZHud9 + +JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKzV3S8 + kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAVKcNR + 9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO538ljg + dslnAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAFYhA7sw8odOsRO8/DUklBOjPnmn + a078oSumgPXXw6AgcoAJv/Qthjo6CCEtrjYfcA9jaBw9/Tii7mDmqDRS5c9ZPL8+ + NEPdHjFCFBOEvlL6uHOgw0Z9Wz+5yCXnJ8oNUEgc3H2NbbzJF6sMBXSPtFS2NOK8 + OsAI9OodMrDd6+lwljrmFoCCkJHDEfE637IcsbgFKkzhO/oNCRK6OrudG4teDahz + Au4LoEYwT730QKC/VQxxEVZobjn9/sTrq9CZlbPYHxX4fz6e00sX7H9i49vk9zQ5 + 5qCm9ljhrQPSa42Q62PPE2BEEGSP2KBm0J+H3vlvCD6+SNc/nMZjrRmgjrI= + -----END CERTIFICATE----- + """ private let key = """ ------BEGIN PRIVATE KEY----- -MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDiC+TGmbSP/nWW -N1tjyNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMi -sUdbd3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZ -Hud9+JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKz -V3S8kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAV -KcNR9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO53 -8ljgdslnAgMBAAECggEBANZNWFNAnYJ2R5xmVuo/GxFk68Ujd4i4TZpPYbhkk+QG -g8I0w5htlEQQkVHfZx2CpTvq8feuAH/YhlA5qeD5WaPwq26q5qsmyV6tQGDgb9lO -w85l6ySZDbwdVOJe2il/MSB6MclSKvTGNm59chJnfHYsmvY3HHq4qsc2F+tRKYMW -pY75LgEbaTUV69J3cbC1wAeVjv0q/krND+YkhYpTxNZhbazK/FHOCvY+zFu9fg0L -zpwbn5fb6wIvqG7tXp7koa3QMn64AXmO/fb5mBd8G2vBGYnxwb7Egwdg/3Dw+BXu -ynQLP7ixWsE2KNfR9Ce1i3YvEo6QDTv2340I3dntxkECgYEA9vdaL4PGyvEbpim4 -kqz1vuug8Iq0nTVDo6jmgH1o+XdcIbW3imXtgi5zUJpj4oDD7/4aufiJZjG64i/v -phe11xeUvh5QNNOzeMymVDoJut97F97KKKTv7bG8Rpon/WzH2I0SoAkECCwmdWAJ -H3nvOCnXEkpbCqmIUvHVURPRDn8CgYEA6lCk3EzFQlbXs3Sj5op61R3Mscx7/35A -eGv5axzbENHt1so+s3Zvyyi1bo4VBcwnKVCvQjmTuLiqrc9VfX8XdbiTUNnEr2u3 -992Ja6DEJTZ9gy5WiviwYnwU2HpjwOVNBb17T0NLoRHkDZ6iXj7NZgwizOki5p3j -/hS0pObSIRkCgYEAiEdOGNIarHoHy9VR6H5QzR2xHYssx2NRA8p8B4MsnhxjVqaz -tUcxnJiNQXkwjRiJBrGthdnD2ASxH4dcMsb6rMpyZcbMc5ouewZS8j9khx4zCqUB -4RPC4eMmBb+jOZEBZlnSYUUYWHokbrij0B61BsTvzUQCoQuUElEoaSkKP3kCgYEA -mwdqXHvK076jjo9w1drvtEu4IDc8H2oH++TsrEr2QiWzaDZ9z71f8BnqGNCW5jQS -AQrqOjXgIArGmqMgXB0Xh4LsrUS4Fpx9ptiD0JsYy8pGtuGUzvQFt9OC80ve7kSI -dnDMwj+zLUmqCrzXjuWcfpUu/UaPGeiDbZuDfcteYhkCgYBLyL5JY7Qd4gVQIhFX -7Sv3sNJN3KZCQHEzut7IwojaxgpuxiFvgsoXXuYolVCQp32oWbYcE2Yke+hOKsTE -sCMAWZiSGN2Nrfea730IYAXkUm8bpEd3VxDXEEv13nxVeQof+JGMdlkldFGaBRDU -oYQsPj00S3/GA9WDapwe81Wl2A== ------END PRIVATE KEY----- -""" + -----BEGIN PRIVATE KEY----- + MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDiC+TGmbSP/nWW + N1tjyNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMi + sUdbd3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZ + Hud9+JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKz + V3S8kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAV + KcNR9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO53 + 8ljgdslnAgMBAAECggEBANZNWFNAnYJ2R5xmVuo/GxFk68Ujd4i4TZpPYbhkk+QG + g8I0w5htlEQQkVHfZx2CpTvq8feuAH/YhlA5qeD5WaPwq26q5qsmyV6tQGDgb9lO + w85l6ySZDbwdVOJe2il/MSB6MclSKvTGNm59chJnfHYsmvY3HHq4qsc2F+tRKYMW + pY75LgEbaTUV69J3cbC1wAeVjv0q/krND+YkhYpTxNZhbazK/FHOCvY+zFu9fg0L + zpwbn5fb6wIvqG7tXp7koa3QMn64AXmO/fb5mBd8G2vBGYnxwb7Egwdg/3Dw+BXu + ynQLP7ixWsE2KNfR9Ce1i3YvEo6QDTv2340I3dntxkECgYEA9vdaL4PGyvEbpim4 + kqz1vuug8Iq0nTVDo6jmgH1o+XdcIbW3imXtgi5zUJpj4oDD7/4aufiJZjG64i/v + phe11xeUvh5QNNOzeMymVDoJut97F97KKKTv7bG8Rpon/WzH2I0SoAkECCwmdWAJ + H3nvOCnXEkpbCqmIUvHVURPRDn8CgYEA6lCk3EzFQlbXs3Sj5op61R3Mscx7/35A + eGv5axzbENHt1so+s3Zvyyi1bo4VBcwnKVCvQjmTuLiqrc9VfX8XdbiTUNnEr2u3 + 992Ja6DEJTZ9gy5WiviwYnwU2HpjwOVNBb17T0NLoRHkDZ6iXj7NZgwizOki5p3j + /hS0pObSIRkCgYEAiEdOGNIarHoHy9VR6H5QzR2xHYssx2NRA8p8B4MsnhxjVqaz + tUcxnJiNQXkwjRiJBrGthdnD2ASxH4dcMsb6rMpyZcbMc5ouewZS8j9khx4zCqUB + 4RPC4eMmBb+jOZEBZlnSYUUYWHokbrij0B61BsTvzUQCoQuUElEoaSkKP3kCgYEA + mwdqXHvK076jjo9w1drvtEu4IDc8H2oH++TsrEr2QiWzaDZ9z71f8BnqGNCW5jQS + AQrqOjXgIArGmqMgXB0Xh4LsrUS4Fpx9ptiD0JsYy8pGtuGUzvQFt9OC80ve7kSI + dnDMwj+zLUmqCrzXjuWcfpUu/UaPGeiDbZuDfcteYhkCgYBLyL5JY7Qd4gVQIhFX + 7Sv3sNJN3KZCQHEzut7IwojaxgpuxiFvgsoXXuYolVCQp32oWbYcE2Yke+hOKsTE + sCMAWZiSGN2Nrfea730IYAXkUm8bpEd3VxDXEEv13nxVeQof+JGMdlkldFGaBRDU + oYQsPj00S3/GA9WDapwe81Wl2A== + -----END PRIVATE KEY----- + """ + +final class BasicInboundTrafficShapingHandler: ChannelDuplexHandler { + typealias OutboundIn = ByteBuffer + typealias InboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + enum ReadState { + case flowingFreely + case pausing + case paused + + mutating func pause() { + switch self { + case .flowingFreely: + self = .pausing + case .pausing, .paused: + () // nothing to do + } + } + + mutating func unpause() -> Bool { + switch self { + case .flowingFreely: + return false // no extra `read` needed + case .pausing: + self = .flowingFreely + return false // no extra `read` needed + case .paused: + self = .flowingFreely + return true // yes, we need an extra read + } + } + + mutating func shouldRead() -> Bool { + switch self { + case .flowingFreely: + return true + case .pausing: + self = .paused + return false + case .paused: + return false + } + } + } + + private let targetBytesPerSecond: Int + private var currentSecondBytesSeen: Int = 0 + private var readState: ReadState = .flowingFreely + + init(targetBytesPerSecond: Int) { + self.targetBytesPerSecond = targetBytesPerSecond + } + + func evaluatePause(context: ChannelHandlerContext) { + if self.currentSecondBytesSeen >= self.targetBytesPerSecond { + self.readState.pause() + } else if self.currentSecondBytesSeen < self.targetBytesPerSecond { + if self.readState.unpause() { + context.read() + } + } + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let loopBoundContext = NIOLoopBound(context, eventLoop: context.eventLoop) + defer { + context.fireChannelRead(data) + } + let buffer = Self.unwrapInboundIn(data) + let byteCount = buffer.readableBytes + self.currentSecondBytesSeen += byteCount + context.eventLoop.assumeIsolated().scheduleTask(in: .seconds(1)) { + self.currentSecondBytesSeen -= byteCount + self.evaluatePause(context: loopBoundContext.value) + } + self.evaluatePause(context: context) + } + + func read(context: ChannelHandlerContext) { + if self.readState.shouldRead() { + context.read() + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift deleted file mode 100644 index 7eb532cf9..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ /dev/null @@ -1,141 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientTests { - static var allTests: [(String, (HTTPClientTests) -> () throws -> Void)] { - return [ - ("testRequestURI", testRequestURI), - ("testBadRequestURI", testBadRequestURI), - ("testSchemaCasing", testSchemaCasing), - ("testURLSocketPathInitializers", testURLSocketPathInitializers), - ("testBadUnixWithBaseURL", testBadUnixWithBaseURL), - ("testConvenienceExecuteMethods", testConvenienceExecuteMethods), - ("testConvenienceExecuteMethodsOverSocket", testConvenienceExecuteMethodsOverSocket), - ("testConvenienceExecuteMethodsOverSecureSocket", testConvenienceExecuteMethodsOverSecureSocket), - ("testGet", testGet), - ("testGetWithDifferentEventLoopBackpressure", testGetWithDifferentEventLoopBackpressure), - ("testPost", testPost), - ("testPostWithGenericBody", testPostWithGenericBody), - ("testPostWithFoundationDataBody", testPostWithFoundationDataBody), - ("testGetHttps", testGetHttps), - ("testGetHttpsWithIP", testGetHttpsWithIP), - ("testGetHTTPSWorksOnMTELGWithIP", testGetHTTPSWorksOnMTELGWithIP), - ("testGetHttpsWithIPv6", testGetHttpsWithIPv6), - ("testGetHTTPSWorksOnMTELGWithIPv6", testGetHTTPSWorksOnMTELGWithIPv6), - ("testPostHttps", testPostHttps), - ("testHttpRedirect", testHttpRedirect), - ("testHttpHostRedirect", testHttpHostRedirect), - ("testPercentEncoded", testPercentEncoded), - ("testPercentEncodedBackslash", testPercentEncodedBackslash), - ("testMultipleContentLengthHeaders", testMultipleContentLengthHeaders), - ("testStreaming", testStreaming), - ("testFileDownload", testFileDownload), - ("testFileDownloadError", testFileDownloadError), - ("testRemoteClose", testRemoteClose), - ("testReadTimeout", testReadTimeout), - ("testConnectTimeout", testConnectTimeout), - ("testDeadline", testDeadline), - ("testCancel", testCancel), - ("testStressCancel", testStressCancel), - ("testHTTPClientAuthorization", testHTTPClientAuthorization), - ("testProxyPlaintext", testProxyPlaintext), - ("testProxyTLS", testProxyTLS), - ("testProxyPlaintextWithCorrectlyAuthorization", testProxyPlaintextWithCorrectlyAuthorization), - ("testProxyPlaintextWithIncorrectlyAuthorization", testProxyPlaintextWithIncorrectlyAuthorization), - ("testUploadStreaming", testUploadStreaming), - ("testEventLoopArgument", testEventLoopArgument), - ("testDecompression", testDecompression), - ("testDecompressionLimit", testDecompressionLimit), - ("testLoopDetectionRedirectLimit", testLoopDetectionRedirectLimit), - ("testCountRedirectLimit", testCountRedirectLimit), - ("testRedirectToTheInitialURLDoesThrowOnFirstRedirect", testRedirectToTheInitialURLDoesThrowOnFirstRedirect), - ("testMultipleConcurrentRequests", testMultipleConcurrentRequests), - ("testWorksWith500Error", testWorksWith500Error), - ("testWorksWithHTTP10Response", testWorksWithHTTP10Response), - ("testWorksWhenServerClosesConnectionAfterReceivingRequest", testWorksWhenServerClosesConnectionAfterReceivingRequest), - ("testSubsequentRequestsWorkWithServerSendingConnectionClose", testSubsequentRequestsWorkWithServerSendingConnectionClose), - ("testSubsequentRequestsWorkWithServerAlternatingBetweenKeepAliveAndClose", testSubsequentRequestsWorkWithServerAlternatingBetweenKeepAliveAndClose), - ("testStressGetHttps", testStressGetHttps), - ("testStressGetHttpsSSLError", testStressGetHttpsSSLError), - ("testFailingConnectionIsReleased", testFailingConnectionIsReleased), - ("testResponseDelayGet", testResponseDelayGet), - ("testIdleTimeoutNoReuse", testIdleTimeoutNoReuse), - ("testStressGetClose", testStressGetClose), - ("testManyConcurrentRequestsWork", testManyConcurrentRequestsWork), - ("testRepeatedRequestsWorkWhenServerAlwaysCloses", testRepeatedRequestsWorkWhenServerAlwaysCloses), - ("testShutdownBeforeTasksCompletion", testShutdownBeforeTasksCompletion), - ("testUncleanShutdownActuallyShutsDown", testUncleanShutdownActuallyShutsDown), - ("testUncleanShutdownCancelsTasks", testUncleanShutdownCancelsTasks), - ("testDoubleShutdown", testDoubleShutdown), - ("testTaskFailsWhenClientIsShutdown", testTaskFailsWhenClientIsShutdown), - ("testRaceNewRequestsVsShutdown", testRaceNewRequestsVsShutdown), - ("testVaryingLoopPreference", testVaryingLoopPreference), - ("testMakeSecondRequestDuringCancelledCallout", testMakeSecondRequestDuringCancelledCallout), - ("testMakeSecondRequestDuringSuccessCallout", testMakeSecondRequestDuringSuccessCallout), - ("testMakeSecondRequestWhilstFirstIsOngoing", testMakeSecondRequestWhilstFirstIsOngoing), - ("testUDSBasic", testUDSBasic), - ("testUDSSocketAndPath", testUDSSocketAndPath), - ("testHTTPPlusUNIX", testHTTPPlusUNIX), - ("testHTTPSPlusUNIX", testHTTPSPlusUNIX), - ("testUseExistingConnectionOnDifferentEL", testUseExistingConnectionOnDifferentEL), - ("testWeRecoverFromServerThatClosesTheConnectionOnUs", testWeRecoverFromServerThatClosesTheConnectionOnUs), - ("testPoolClosesIdleConnections", testPoolClosesIdleConnections), - ("testRacePoolIdleConnectionsAndGet", testRacePoolIdleConnectionsAndGet), - ("testAvoidLeakingTLSHandshakeCompletionPromise", testAvoidLeakingTLSHandshakeCompletionPromise), - ("testAsyncShutdown", testAsyncShutdown), - ("testAsyncShutdownDefaultQueue", testAsyncShutdownDefaultQueue), - ("testValidationErrorsAreSurfaced", testValidationErrorsAreSurfaced), - ("testUploadsReallyStream", testUploadsReallyStream), - ("testUploadStreamingCallinToleratedFromOtsideEL", testUploadStreamingCallinToleratedFromOtsideEL), - ("testWeHandleUsSendingACloseHeaderCorrectly", testWeHandleUsSendingACloseHeaderCorrectly), - ("testWeHandleUsReceivingACloseHeaderCorrectly", testWeHandleUsReceivingACloseHeaderCorrectly), - ("testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly), - ("testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly), - ("testLoggingCorrectlyAttachesRequestInformationEvenAfterDuringRedirect", testLoggingCorrectlyAttachesRequestInformationEvenAfterDuringRedirect), - ("testLoggingCorrectlyAttachesRequestInformation", testLoggingCorrectlyAttachesRequestInformation), - ("testNothingIsLoggedAtInfoOrHigher", testNothingIsLoggedAtInfoOrHigher), - ("testAllMethodsLog", testAllMethodsLog), - ("testClosingIdleConnectionsInPoolLogsInTheBackground", testClosingIdleConnectionsInPoolLogsInTheBackground), - ("testUploadStreamingNoLength", testUploadStreamingNoLength), - ("testConnectErrorPropagatedToDelegate", testConnectErrorPropagatedToDelegate), - ("testDelegateCallinsTolerateRandomEL", testDelegateCallinsTolerateRandomEL), - ("testContentLengthTooLongFails", testContentLengthTooLongFails), - ("testContentLengthTooShortFails", testContentLengthTooShortFails), - ("testBodyUploadAfterEndFails", testBodyUploadAfterEndFails), - ("testNoBytesSentOverBodyLimit", testNoBytesSentOverBodyLimit), - ("testDoubleError", testDoubleError), - ("testSSLHandshakeErrorPropagation", testSSLHandshakeErrorPropagation), - ("testSSLHandshakeErrorPropagationDelayedClose", testSSLHandshakeErrorPropagationDelayedClose), - ("testWeCloseConnectionsWhenConnectionCloseSetByServer", testWeCloseConnectionsWhenConnectionCloseSetByServer), - ("testBiDirectionalStreaming", testBiDirectionalStreaming), - ("testSynchronousHandshakeErrorReporting", testSynchronousHandshakeErrorReporting), - ("testFileDownloadChunked", testFileDownloadChunked), - ("testCloseWhileBackpressureIsExertedIsFine", testCloseWhileBackpressureIsExertedIsFine), - ("testErrorAfterCloseWhileBackpressureExerted", testErrorAfterCloseWhileBackpressureExerted), - ("testRequestSpecificTLS", testRequestSpecificTLS), - ("testConnectionPoolSizeConfigValueIsRespected", testConnectionPoolSizeConfigValueIsRespected), - ("testRequestWithHeaderTransferEncodingIdentityDoesNotFail", testRequestWithHeaderTransferEncodingIdentityDoesNotFail), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 6bb4dd9b4..054cf3487 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -12,15 +12,16 @@ // //===----------------------------------------------------------------------===// -/* NOT @testable */ import AsyncHTTPClient // Tests that need @testable go into HTTPClientInternalTests.swift -#if canImport(Network) -import Network -#endif +import AsyncHTTPClient // NOT @testable - tests that need @testable go into HTTPClientInternalTests.swift +import Atomics +import InMemoryLogging import Logging import NIOConcurrencyHelpers import NIOCore +import NIOEmbedded import NIOFoundationCompat import NIOHTTP1 +import NIOHTTP2 import NIOHTTPCompression import NIOPosix import NIOSSL @@ -28,60 +29,11 @@ import NIOTestUtils import NIOTransportServices import XCTest -class HTTPClientTests: XCTestCase { - typealias Request = HTTPClient.Request - - var clientGroup: EventLoopGroup! - var serverGroup: EventLoopGroup! - var defaultHTTPBin: HTTPBin! - var defaultClient: HTTPClient! - var backgroundLogStore: CollectEverythingLogHandler.LogStore! - - var defaultHTTPBinURLPrefix: String { - return "http://localhost:\(self.defaultHTTPBin.port)/" - } - - override func setUp() { - XCTAssertNil(self.clientGroup) - XCTAssertNil(self.serverGroup) - XCTAssertNil(self.defaultHTTPBin) - XCTAssertNil(self.defaultClient) - XCTAssertNil(self.backgroundLogStore) - - self.clientGroup = getDefaultEventLoopGroup(numberOfThreads: 1) - self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - self.defaultHTTPBin = HTTPBin() - self.backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: self.backgroundLogStore!) - }) - backgroundLogger.logLevel = .trace - self.defaultClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - backgroundActivityLogger: backgroundLogger) - } - - override func tearDown() { - if let defaultClient = self.defaultClient { - XCTAssertNoThrow(try defaultClient.syncShutdown()) - self.defaultClient = nil - } - - XCTAssertNotNil(self.defaultHTTPBin) - XCTAssertNoThrow(try self.defaultHTTPBin.shutdown()) - self.defaultHTTPBin = nil - - XCTAssertNotNil(self.clientGroup) - XCTAssertNoThrow(try self.clientGroup.syncShutdownGracefully()) - self.clientGroup = nil - - XCTAssertNotNil(self.serverGroup) - XCTAssertNoThrow(try self.serverGroup.syncShutdownGracefully()) - self.serverGroup = nil - - XCTAssertNotNil(self.backgroundLogStore) - self.backgroundLogStore = nil - } +#if canImport(Network) +import Network +#endif +final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testRequestURI() throws { let request1 = try Request(url: "https://someserver.com:8888/some/path?foo=bar") XCTAssertEqual(request1.url.host, "someserver.com") @@ -94,8 +46,12 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(request2.url.path, "") let request3 = try Request(url: "unix:///tmp/file") - XCTAssertNil(request3.url.host) XCTAssertEqual(request3.host, "") + #if os(Linux) && compiler(<6.1) + XCTAssertEqual(request3.url.host, "") + #else + XCTAssertNil(request3.url.host) + #endif XCTAssertEqual(request3.url.path, "/tmp/file") XCTAssertEqual(request3.port, 80) XCTAssertFalse(request3.useTLS) @@ -169,7 +125,10 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(url.scheme, "http+unix") XCTAssertEqual(url.host, "/tmp/file with spacesと漢字") XCTAssertEqual(url.path, "/file/path") - XCTAssertEqual(url.absoluteString, "http+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path") + XCTAssertEqual( + url.absoluteString, + "http+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path" + ) } let url5 = URL(httpsURLWithSocketPath: "/tmp/file") @@ -205,14 +164,11 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(url.scheme, "https+unix") XCTAssertEqual(url.host, "/tmp/file with spacesと漢字") XCTAssertEqual(url.path, "/file/path") - XCTAssertEqual(url.absoluteString, "https+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path") + XCTAssertEqual( + url.absoluteString, + "https+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path" + ) } - - let url9 = URL(httpURLWithSocketPath: "/tmp/file", uri: " ") - XCTAssertNil(url9) - - let url10 = URL(httpsURLWithSocketPath: "/tmp/file", uri: " ") - XCTAssertNil(url10) } func testBadUnixWithBaseURL() { @@ -224,55 +180,116 @@ class HTTPClientTests: XCTestCase { } func testConvenienceExecuteMethods() throws { - XCTAssertEqual(["GET"[...]], - try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["POST"[...]], - try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["PATCH"[...]], - try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["PUT"[...]], - try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["DELETE"[...]], - try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["GET"[...]], - try self.defaultClient.execute(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["CHECKOUT"[...]], - try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["POST"[...]], + try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["PATCH"[...]], + try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["PUT"[...]], + try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["DELETE"[...]], + try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.execute(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["CHECKOUT"[...]], + try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) } func testConvenienceExecuteMethodsOverSocket() throws { - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(["GET"[...]], - try self.defaultClient.execute(socketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["GET"[...]], - try self.defaultClient.execute(.GET, socketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["POST"[...]], - try self.defaultClient.execute(.POST, socketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - }) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.execute(socketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.execute(.GET, socketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["POST"[...]], + try self.defaultClient.execute(.POST, socketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + } + ) } func testConvenienceExecuteMethodsOverSecureSocket() throws { - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true, compress: false), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localSocketPathHTTPBin = HTTPBin( + .http1_1(ssl: true, compress: false), + bindTarget: .unixDomainSocket(path) + ) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(["GET"[...]], - try localClient.execute(secureSocketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["GET"[...]], - try localClient.execute(.GET, secureSocketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["POST"[...]], - try localClient.execute(.POST, secureSocketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - }) + XCTAssertEqual( + ["GET"[...]], + try localClient.execute(secureSocketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["GET"[...]], + try localClient.execute(.GET, secureSocketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["POST"[...]], + try localClient.execute(.POST, secureSocketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + } + ) } func testGet() throws { @@ -288,7 +305,8 @@ class HTTPClientTests: XCTestCase { } func testPost() throws { - let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .string("1234")).wait() + let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .string("1234")) + .wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -297,10 +315,10 @@ class HTTPClientTests: XCTestCase { } func testPostWithGenericBody() throws { - let bodyData = Array("hello, world!").lazy.map { $0.uppercased().first!.asciiValue! } - let erasedData = AnyRandomAccessCollection(bodyData) + let bodyData = Array(Array("hello, world!").lazy.map { $0.uppercased().first!.asciiValue! }) - let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .bytes(erasedData)).wait() + let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .bytes(bodyData)) + .wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -311,7 +329,8 @@ class HTTPClientTests: XCTestCase { func testPostWithFoundationDataBody() throws { let bodyData = Data("hello, world!".utf8) - let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .data(bodyData)).wait() + let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .data(bodyData)) + .wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -321,8 +340,10 @@ class HTTPClientTests: XCTestCase { func testGetHttps() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -334,8 +355,10 @@ class HTTPClientTests: XCTestCase { func testGetHttpsWithIP() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -353,8 +376,10 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(group), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(group), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -367,8 +392,10 @@ class HTTPClientTests: XCTestCase { func testGetHttpsWithIPv6() throws { try XCTSkipUnless(canBindIPv6Loopback, "Requires IPv6") let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .localhostIPv6RandomPort) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -387,8 +414,10 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .localhostIPv6RandomPort) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(group), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(group), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -400,14 +429,20 @@ class HTTPClientTests: XCTestCase { func testPostHttps() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) } - let request = try Request(url: "https://localhost:\(localHTTPBin.port)/post", method: .POST, body: .string("1234")) + let request = try Request( + url: "https://localhost:\(localHTTPBin.port)/post", + method: .POST, + body: .string("1234") + ) let response = try localClient.execute(request: request).wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } @@ -419,8 +454,13 @@ class HTTPClientTests: XCTestCase { func testHttpRedirect() throws { let httpsBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 10, allowCycles: true) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -429,103 +469,246 @@ class HTTPClientTests: XCTestCase { var response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/302").wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, self.defaultHTTPBinURLPrefix + "ok") + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [ + self.defaultHTTPBinURLPrefix + "redirect/302", + self.defaultHTTPBinURLPrefix + "ok", + ] + ) - response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/https?port=\(httpsBin.port)").wait() + response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/https?port=\(httpsBin.port)") + .wait() XCTAssertEqual(response.status, .ok) - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpSocketPath in - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpsSocketPath in - let socketHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(httpSocketPath)) - let socketHTTPSBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(httpsSocketPath)) - defer { - XCTAssertNoThrow(try socketHTTPBin.shutdown()) - XCTAssertNoThrow(try socketHTTPSBin.shutdown()) - } - - // From HTTP or HTTPS to HTTP+UNIX should fail to redirect - var targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - var request = try Request(url: self.defaultHTTPBinURLPrefix + "redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - var response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - request = try Request(url: "https://localhost:\(httpsBin.port)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - // From HTTP or HTTPS to HTTPS+UNIX should also fail to redirect - targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: self.defaultHTTPBinURLPrefix + "redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - request = try Request(url: "https://localhost:\(httpsBin.port)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - // ... while HTTP+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed - targetURL = self.defaultHTTPBinURLPrefix + "ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https://localhost:\(httpsBin.port)/ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - // ... and HTTPS+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed - targetURL = self.defaultHTTPBinURLPrefix + "ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https://localhost:\(httpsBin.port)/ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpSocketPath in + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpsSocketPath in + let socketHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(httpSocketPath)) + let socketHTTPSBin = HTTPBin( + .http1_1(ssl: true), + bindTarget: .unixDomainSocket(httpsSocketPath) + ) + defer { + XCTAssertNoThrow(try socketHTTPBin.shutdown()) + XCTAssertNoThrow(try socketHTTPSBin.shutdown()) + } - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - }) - }) + // From HTTP or HTTPS to HTTP+UNIX should fail to redirect + var targetURL = + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + var request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + var response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) + + request = try Request( + url: "https://localhost:\(httpsBin.port)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) + + // From HTTP or HTTPS to HTTPS+UNIX should also fail to redirect + targetURL = + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) + + request = try Request( + url: "https://localhost:\(httpsBin.port)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) + + // ... while HTTP+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed + targetURL = self.defaultHTTPBinURLPrefix + "ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = "https://localhost:\(httpsBin.port)/ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + // ... and HTTPS+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed + targetURL = self.defaultHTTPBinURLPrefix + "ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = "https://localhost:\(httpsBin.port)/ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + } + ) + } + ) } func testHttpHostRedirect() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 10, allowCycles: true) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -552,12 +735,37 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(.ok, response.status) } + func testLeadingSlashRelativeURL() throws { + let noLeadingSlashURL = URL( + string: "percent%2Fencoded/hello", + relativeTo: URL(string: self.defaultHTTPBinURLPrefix)! + )! + let withLeadingSlashURL = URL( + string: "/percent%2Fencoded/hello", + relativeTo: URL(string: self.defaultHTTPBinURLPrefix)! + )! + + let noLeadingSlashURLRequest = try HTTPClient.Request(url: noLeadingSlashURL, method: .GET) + let withLeadingSlashURLRequest = try HTTPClient.Request(url: withLeadingSlashURL, method: .GET) + + let noLeadingSlashURLResponse = try self.defaultClient.execute(request: noLeadingSlashURLRequest).wait() + let withLeadingSlashURLResponse = try self.defaultClient.execute(request: withLeadingSlashURLRequest).wait() + + XCTAssertEqual(noLeadingSlashURLResponse.status, .ok) + XCTAssertEqual(withLeadingSlashURLResponse.status, .ok) + } + func testMultipleContentLengthHeaders() throws { let body = ByteBuffer(string: "hello world!") var headers = HTTPHeaders() headers.add(name: "Content-Length", value: "12") - let request = try Request(url: self.defaultHTTPBinURLPrefix + "post", method: .POST, headers: headers, body: .byteBuffer(body)) + let request = try Request( + url: self.defaultHTTPBinURLPrefix + "post", + method: .POST, + headers: headers, + body: .byteBuffer(body) + ) let response = try self.defaultClient.execute(request: request).wait() // if the library adds another content length header we'll get a bad request error. XCTAssertEqual(.ok, response.status) @@ -577,11 +785,11 @@ class HTTPClientTests: XCTestCase { var request = try Request(url: self.defaultHTTPBinURLPrefix + "events/10/content-length") request.headers.add(name: "Accept", value: "text/event-stream") - let progress = - try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in + let response = + try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Response in let delegate = try FileDownloadDelegate(path: path) - let progress = try self.defaultClient.execute( + let response = try self.defaultClient.execute( request: request, delegate: delegate ) @@ -589,24 +797,30 @@ class HTTPClientTests: XCTestCase { try XCTAssertEqual(50, TemporaryFileHelpers.fileSize(path: path)) - return progress + return response } - XCTAssertEqual(50, progress.totalBytes) - XCTAssertEqual(50, progress.receivedBytes) + XCTAssertEqual(.ok, response.head.status) + XCTAssertEqual("50", response.head.headers.first(name: "content-length")) + + XCTAssertEqual(50, response.totalBytes) + XCTAssertEqual(50, response.receivedBytes) } func testFileDownloadError() throws { var request = try Request(url: self.defaultHTTPBinURLPrefix + "not-found") request.headers.add(name: "Accept", value: "text/event-stream") - let progress = - try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in - let delegate = try FileDownloadDelegate(path: path, reportHead: { - XCTAssertEqual($0.status, .notFound) - }) + let response = + try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Response in + let delegate = try FileDownloadDelegate( + path: path, + reportHead: { + XCTAssertEqual($0.status, .notFound) + } + ) - let progress = try self.defaultClient.execute( + let response = try self.defaultClient.execute( request: request, delegate: delegate ) @@ -614,11 +828,43 @@ class HTTPClientTests: XCTestCase { XCTAssertFalse(TemporaryFileHelpers.fileExists(path: path)) - return progress + return response + } + + XCTAssertEqual(.notFound, response.head.status) + XCTAssertFalse(response.head.headers.contains(name: "content-length")) + + XCTAssertEqual(nil, response.totalBytes) + XCTAssertEqual(0, response.receivedBytes) + } + + func testFileDownloadCustomError() throws { + let request = try Request(url: self.defaultHTTPBinURLPrefix + "get") + struct CustomError: Equatable, Error {} + + try TemporaryFileHelpers.withTemporaryFilePath { path in + let delegate = try FileDownloadDelegate( + path: path, + reportHead: { task, head in + XCTAssertEqual(head.status, .ok) + task.fail(reason: CustomError()) + }, + reportProgress: { _, _ in + XCTFail("should never be called") + } + ) + XCTAssertThrowsError( + try self.defaultClient.execute( + request: request, + delegate: delegate + ) + .wait() + ) { error in + XCTAssertEqualTypeAndValue(error, CustomError()) } - XCTAssertEqual(nil, progress.totalBytes) - XCTAssertEqual(0, progress.receivedBytes) + XCTAssertFalse(TemporaryFileHelpers.fileExists(path: path)) + } } func testRemoteClose() { @@ -628,8 +874,10 @@ class HTTPClientTests: XCTestCase { } func testReadTimeout() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(read: .milliseconds(150)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(read: .milliseconds(150))) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -640,22 +888,89 @@ class HTTPClientTests: XCTestCase { } } - func testConnectTimeout() { - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150)))) + func testWriteTimeout() throws { + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(write: .nanoseconds(10))) + ) + + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + } + + // Create a request that writes a chunk, then waits longer than the configured write timeout, + // and then writes again. This should trigger a write timeout error. + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "post", + method: .POST, + headers: ["transfer-encoding": "chunked"], + body: .stream { streamWriter in + _ = streamWriter.write(.byteBuffer(.init())) + + let promise = localClient.eventLoopGroup.next().makePromise(of: Void.self) + localClient.eventLoopGroup.next().scheduleTask(in: .milliseconds(3)) { + streamWriter.write(.byteBuffer(.init())).cascade(to: promise) + } + + return promise.futureResult + } + ) + + XCTAssertThrowsError(try localClient.execute(request: request).wait()) { + XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) + } + } + + func testConnectTimeout() throws { + #if os(Linux) + // 198.51.100.254 is reserved for documentation only and therefore should not accept any TCP connection + let url = "http://198.51.100.254/get" + #else + // on macOS we can use the TCP backlog behaviour when the queue is full to simulate a non reachable server. + // this makes this test a bit more stable if `198.51.100.254` actually responds to connection attempt. + // The backlog behaviour on Linux can not be used to simulate a non-reachable server. + // Linux sends a `SYN/ACK` back even if the `backlog` queue is full as it has two queues. + // The second queue is not limit by `ChannelOptions.backlog` but by `/proc/sys/net/ipv4/tcp_max_syn_backlog`. + + let serverChannel = try ServerBootstrap(group: self.serverGroup) + .serverChannelOption(ChannelOptions.backlog, value: 1) + .serverChannelOption(ChannelOptions.autoRead, value: false) + .bind(host: "127.0.0.1", port: 0) + .wait() + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + let port = serverChannel.localAddress!.port! + let firstClientChannel = try ClientBootstrap(group: self.serverGroup) + .connect(host: "127.0.0.1", port: port) + .wait() + defer { + XCTAssertNoThrow(try firstClientChannel.close().wait()) + } + let url = "http://localhost:\(port)/get" + #endif + + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150))) + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } - // This must throw as 198.51.100.254 is reserved for documentation only - XCTAssertThrowsError(try httpClient.get(url: "http://198.51.100.254/get").wait()) { - XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) + XCTAssertThrowsError(try httpClient.get(url: url).wait()) { + XCTAssertEqualTypeAndValue($0, HTTPClientError.connectTimeout) } } func testDeadline() { - XCTAssertThrowsError(try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "wait", deadline: .now() + .milliseconds(150)).wait()) { + XCTAssertThrowsError( + try self.defaultClient.get( + url: self.defaultHTTPBinURLPrefix + "wait", + deadline: .now() + .milliseconds(150) + ).wait() + ) { XCTAssertEqual($0 as? HTTPClientError, .deadlineExceeded) } } @@ -741,7 +1056,13 @@ class HTTPClientTests: XCTestCase { let localHTTPBin = HTTPBin(proxy: .simulate(authorization: "Basic YWxhZGRpbjpvcGVuc2VzYW1l")) let localClient = HTTPClient( eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .server(host: "localhost", port: localHTTPBin.port, authorization: .basic(username: "aladdin", password: "opensesame"))) + configuration: .init( + proxy: .server( + host: "localhost", + port: localHTTPBin.port, + authorization: .basic(username: "aladdin", password: "opensesame") + ) + ) ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -753,11 +1074,19 @@ class HTTPClientTests: XCTestCase { func testProxyPlaintextWithIncorrectlyAuthorization() throws { let localHTTPBin = HTTPBin(proxy: .simulate(authorization: "Basic YWxhZGRpbjpvcGVuc2VzYW1l")) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .server(host: "localhost", - port: localHTTPBin.port, - authorization: .basic(username: "aladdin", - password: "opensesamefoo")))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + proxy: .server( + host: "localhost", + port: localHTTPBin.port, + authorization: .basic( + username: "aladdin", + password: "opensesamefoo" + ) + ) + ).enableFastFailureModeForTesting() + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -770,7 +1099,7 @@ class HTTPClientTests: XCTestCase { } func testUploadStreaming() throws { - let body: HTTPClient.Body = .stream(length: 8) { writer in + let body: HTTPClient.Body = .stream(contentLength: 8) { writer in let buffer = ByteBuffer(string: "1234") return writer.write(.byteBuffer(buffer)).flatMap { let buffer = ByteBuffer(string: "4321") @@ -787,48 +1116,57 @@ class HTTPClientTests: XCTestCase { } func testEventLoopArgument() throws { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(redirectConfiguration: .follow(max: 10, allowCycles: true)) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } - class EventLoopValidatingDelegate: HTTPClientResponseDelegate { + final class EventLoopValidatingDelegate: HTTPClientResponseDelegate { typealias Response = Bool let eventLoop: EventLoop - var result = false + let result = NIOLockedValueBox(false) init(eventLoop: EventLoop) { self.eventLoop = eventLoop } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.result = task.eventLoop === self.eventLoop + self.result.withLockedValue { $0 = task.eventLoop === self.eventLoop } return task.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Bool { - return self.result + self.result.withLockedValue { $0 } } } let eventLoop = self.clientGroup.next() let delegate = EventLoopValidatingDelegate(eventLoop: eventLoop) var request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get") - var response = try localClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait() + var response = try localClient.execute( + request: request, + delegate: delegate, + eventLoop: .delegate(on: eventLoop) + ).wait() XCTAssertEqual(true, response) // redirect request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "redirect/302") - response = try localClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait() + response = try localClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)) + .wait() XCTAssertEqual(true, response) } func testDecompression() throws { let localHTTPBin = HTTPBin(.http1_1(compress: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(decompression: .enabled(limit: .none))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(decompression: .enabled(limit: .none)) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -837,7 +1175,8 @@ class HTTPClientTests: XCTestCase { var body = "" for _ in 1...1000 { - body += "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + body += + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." } for algorithm in [nil, "gzip", "deflate"] { @@ -862,9 +1201,56 @@ class HTTPClientTests: XCTestCase { } } + func testDecompressionHTTP2() throws { + let localHTTPBin = HTTPBin(.http2(compress: true)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + certificateVerification: .none, + decompression: .enabled(limit: .none) + ) + ) + + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + + var body = "" + for _ in 1...1000 { + body += + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + } + + for algorithm: String? in [nil] { + var request = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/post", method: .POST) + request.body = .string(body) + if let algorithm = algorithm { + request.headers.add(name: "Accept-Encoding", value: algorithm) + } + + let response = try localClient.execute(request: request).wait() + var responseBody = try XCTUnwrap(response.body) + let data = try responseBody.readJSONDecodable(RequestInfo.self, length: responseBody.readableBytes) + + XCTAssertEqual(.ok, response.status) + let contentLength = try XCTUnwrap(response.headers["Content-Length"].first.flatMap { Int($0) }) + XCTAssertGreaterThan(body.count, contentLength) + if let algorithm = algorithm { + XCTAssertEqual(algorithm, response.headers["Content-Encoding"].first) + } else { + XCTAssertEqual("deflate", response.headers["Content-Encoding"].first) + } + XCTAssertEqual(body, data?.data) + } + } + func testDecompressionLimit() throws { let localHTTPBin = HTTPBin(.http1_1(compress: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(decompression: .enabled(limit: .ratio(1)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(decompression: .enabled(limit: .ratio(1))) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -882,30 +1268,47 @@ class HTTPClientTests: XCTestCase { func testLoopDetectionRedirectLimit() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 5, allowCycles: false))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 5, allowCycles: false) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) } - XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").wait(), "Should fail with redirect limit") { error in + XCTAssertThrowsError( + try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").wait(), + "Should fail with redirect limit" + ) { error in XCTAssertEqual(error as? HTTPClientError, HTTPClientError.redirectCycleDetected) } } func testCountRedirectLimit() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 10, allowCycles: true) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) } - XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").timeout(after: .seconds(10)).wait()) { error in + XCTAssertThrowsError( + try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").timeout( + after: .seconds(10) + ).wait() + ) { error in XCTAssertEqual(error as? HTTPClientError, HTTPClientError.redirectLimitReached) } } @@ -923,13 +1326,15 @@ class HTTPClientTests: XCTestCase { defer { XCTAssertNoThrow(try localClient.syncShutdown()) } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "https://localhost:\(localHTTPBin.port)/redirect/target", - method: .GET, - headers: [ - "X-Target-Redirect-URL": "/redirect/target", - ] - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/redirect/target", + method: .GET, + headers: [ + "X-Target-Redirect-URL": "/redirect/target" + ] + ) + ) guard let request = maybeRequest else { return } XCTAssertThrowsError( @@ -943,14 +1348,18 @@ class HTTPClientTests: XCTestCase { let numberOfRequestsPerThread = 1000 let numberOfParallelWorkers = 5 - final class HTTPServer: ChannelInboundHandler { + final class HTTPServer: ChannelInboundHandler, Sendable { typealias InboundIn = HTTPServerRequestPart typealias OutboundOut = HTTPServerResponsePart func channelRead(context: ChannelHandlerContext, data: NIOAny) { if case .end = self.unwrapInboundIn(data) { - let responseHead = HTTPServerResponsePart.head(.init(version: .init(major: 1, minor: 1), - status: .ok)) + let responseHead = HTTPServerResponsePart.head( + .init( + version: .init(major: 1, minor: 1), + status: .ok + ) + ) context.write(self.wrapOutboundOut(responseHead), promise: nil) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } @@ -963,28 +1372,33 @@ class HTTPClientTests: XCTestCase { } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) - .serverChannelOption(ChannelOptions.backlog, value: .init(numberOfParallelWorkers)) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, - withServerUpgrade: nil, - withErrorHandling: false).flatMap { - channel.pipeline.addHandler(HTTPServer()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) + .serverChannelOption(ChannelOptions.backlog, value: .init(numberOfParallelWorkers)) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline( + withPipeliningAssistance: false, + withServerUpgrade: nil, + withErrorHandling: false + ).flatMap { + channel.pipeline.addHandler(HTTPServer()) + } } - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) - defer { + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) + defer { XCTAssertNoThrow(try server?.close().wait()) } + let url = "http://127.0.0.1:\(server?.localAddress?.port ?? -1)/hello" let g = DispatchGroup() + let defaultClient = self.defaultClient! for workerID in 0..]() - for _ in 1...requestCount { - let req = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, headers: ["X-internal-delay": "100"]) - futureResults.append(localClient.execute(request: req)) } - XCTAssertNoThrow(try EventLoopFuture.andAllSucceed(futureResults, on: eventLoop).wait()) - } - func testStressGetHttpsSSLError() throws { let request = try Request(url: "https://localhost:\(self.defaultHTTPBin.port)/wait", method: .GET) let tasks = (1...100).map { _ -> HTTPClient.Task in - self.defaultClient.execute(request: request, delegate: TestHTTPDelegate()) + localClient.execute(request: request, delegate: TestHTTPDelegate()) } - let results = try EventLoopFuture.whenAllComplete(tasks.map { $0.futureResult }, on: self.defaultClient.eventLoopGroup.next()).wait() + let results = try EventLoopFuture.whenAllComplete( + tasks.map { $0.futureResult }, + on: localClient.eventLoopGroup.next() + ).wait() for result in results { switch result { @@ -1179,9 +1634,10 @@ class HTTPClientTests: XCTestCase { // We're speaking TLS to a plain text server. This will cause the handshake to fail but given // that the bytes "HTTP/1.1" aren't the start of a valid TLS packet, we can also get // errSSLPeerProtocolVersion because the first bytes contain the version. - XCTAssert(clientError.status == errSSLHandshakeFail || - clientError.status == errSSLPeerProtocolVersion, - "unexpected NWTLSError with status \(clientError.status)") + XCTAssert( + clientError.status == errSSLHandshakeFail || clientError.status == errSSLPeerProtocolVersion, + "unexpected NWTLSError with status \(clientError.status)" + ) #endif } else { guard let clientError = error as? NIOSSLError, case NIOSSLError.handshakeFailed = clientError else { @@ -1193,6 +1649,97 @@ class HTTPClientTests: XCTestCase { } } + func testSelfSignedCertificateIsRejectedWithCorrectError() throws { + /// key + cert was created with the follwing command: + /// openssl req -x509 -newkey rsa:4096 -keyout self_signed_key.pem -out self_signed_cert.pem -sha256 -days 99999 -nodes -subj '/CN=localhost' + let certPath = Bundle.module.path(forResource: "self_signed_cert", ofType: "pem")! + let keyPath = Bundle.module.path(forResource: "self_signed_key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) + let configuration = try TLSConfiguration.makeServerConfiguration( + certificateChain: NIOSSLCertificate.fromPEMFile(certPath).map { .certificate($0) }, + privateKey: .privateKey(key) + ) + let sslContext = try NIOSSLContext(configuration: configuration) + + let server = ServerBootstrap(group: serverGroup) + .childChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + } + let serverChannel = try server.bind(host: "localhost", port: 0).wait() + defer { XCTAssertNoThrow(try serverChannel.close().wait()) } + let port = serverChannel.localAddress!.port! + + let config = HTTPClient.Configuration().enableFastFailureModeForTesting() + + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: config) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(port)").wait()) { error in + #if canImport(Network) + guard let nwTLSError = error as? HTTPClient.NWTLSError else { + XCTFail("could not cast \(error) of type \(type(of: error)) to \(HTTPClient.NWTLSError.self)") + return + } + XCTAssertEqual(nwTLSError.status, errSSLBadCert, "unexpected tls error: \(nwTLSError)") + #else + guard let sslError = error as? NIOSSLError, + case .handshakeFailed(.sslError) = sslError + else { + XCTFail("unexpected error \(error)") + return + } + #endif + } + } + + func testSelfSignedCertificateIsRejectedWithCorrectErrorIfRequestDeadlineIsExceeded() throws { + /// key + cert was created with the follwing command: + /// openssl req -x509 -newkey rsa:4096 -keyout self_signed_key.pem -out self_signed_cert.pem -sha256 -days 99999 -nodes -subj '/CN=localhost' + let certPath = Bundle.module.path(forResource: "self_signed_cert", ofType: "pem")! + let keyPath = Bundle.module.path(forResource: "self_signed_key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) + let configuration = try TLSConfiguration.makeServerConfiguration( + certificateChain: NIOSSLCertificate.fromPEMFile(certPath).map { .certificate($0) }, + privateKey: .privateKey(key) + ) + let sslContext = try NIOSSLContext(configuration: configuration) + + let server = ServerBootstrap(group: serverGroup) + .childChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } + } + let serverChannel = try server.bind(host: "localhost", port: 0).wait() + defer { XCTAssertNoThrow(try serverChannel.close().wait()) } + let port = serverChannel.localAddress!.port! + + let config = HTTPClient.Configuration().enableFastFailureModeForTesting() + + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: config) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + + XCTAssertThrowsError( + try localClient.get(url: "https://localhost:\(port)", deadline: .now() + .seconds(2)).wait() + ) { error in + #if canImport(Network) + guard let nwTLSError = error as? HTTPClient.NWTLSError else { + XCTFail("could not cast \(error) of type \(type(of: error)) to \(HTTPClient.NWTLSError.self)") + return + } + XCTAssertEqual(nwTLSError.status, errSSLBadCert, "unexpected tls error: \(nwTLSError)") + #else + guard let sslError = error as? NIOSSLError, + case .handshakeFailed(.sslError) = sslError + else { + XCTFail("unexpected error \(error)") + return + } + #endif + } + } + func testFailingConnectionIsReleased() { let localHTTPBin = HTTPBin(.refuse) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) @@ -1211,37 +1758,22 @@ class HTTPClientTests: XCTestCase { } } - func testResponseDelayGet() throws { - let req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", - method: .GET, - headers: ["X-internal-delay": "2000"], - body: nil) - let start = Date() - let response = try! self.defaultClient.execute(request: req).wait() - XCTAssertGreaterThan(Date().timeIntervalSince(start), 2) - XCTAssertEqual(response.status, .ok) - } - - func testIdleTimeoutNoReuse() throws { - var req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .GET) - XCTAssertNoThrow(try self.defaultClient.execute(request: req, deadline: .now() + .seconds(2)).wait()) - req.headers.add(name: "X-internal-delay", value: "2500") - try self.defaultClient.eventLoopGroup.next().scheduleTask(in: .milliseconds(250)) {}.futureResult.wait() - XCTAssertNoThrow(try self.defaultClient.execute(request: req).timeout(after: .seconds(10)).wait()) - } - func testStressGetClose() throws { let eventLoop = self.defaultClient.eventLoopGroup.next() let requestCount = 200 var futureResults = [EventLoopFuture]() for _ in 1...requestCount { - let req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", - method: .GET, - headers: ["X-internal-delay": "5", "Connection": "close"]) + let req = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .GET, + headers: ["X-internal-delay": "5", "Connection": "close"] + ) futureResults.append(self.defaultClient.execute(request: req)) } - XCTAssertNoThrow(try EventLoopFuture.andAllComplete(futureResults, on: eventLoop) - .timeout(after: .seconds(10)).wait()) + XCTAssertNoThrow( + try EventLoopFuture.andAllComplete(futureResults, on: eventLoop) + .timeout(after: .seconds(10)).wait() + ) } func testManyConcurrentRequestsWork() { @@ -1256,13 +1788,14 @@ class HTTPClientTests: XCTestCase { for w in 0..]() for i in 1...100 { - let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .GET, headers: ["X-internal-delay": "10"]) + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .GET, + headers: ["X-internal-delay": "10"] + ) let preference: HTTPClient.EventLoopPreference if i <= 50 { preference = .delegateAndChannel(on: first) @@ -1496,15 +2055,18 @@ class HTTPClientTests: XCTestCase { let seenError = DispatchGroup() seenError.enter() var maybeSecondRequest: EventLoopFuture? - XCTAssertNoThrow(maybeSecondRequest = try el.submit { - let neverSucceedingRequest = localClient.get(url: url) - let secondRequest = neverSucceedingRequest.flatMapError { error in - XCTAssertEqual(.cancelled, error as? HTTPClientError) - seenError.leave() - return localClient.get(url: url) // <== this is the main part, during the error callout, we call back in - } - return secondRequest - }.wait()) + XCTAssertNoThrow( + maybeSecondRequest = try el.submit { + let neverSucceedingRequest = localClient.get(url: url) + let secondRequest = neverSucceedingRequest.flatMapError { error in + XCTAssertEqual(.cancelled, error as? HTTPClientError) + seenError.leave() + // v this is the main part, during the error callout, we call back in + return localClient.get(url: url) + } + return secondRequest + }.wait() + ) guard let secondRequest = maybeSecondRequest else { XCTFail("couldn't get request future") @@ -1530,13 +2092,15 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try localClient.syncShutdown()) } - XCTAssertEqual(.ok, - try el.flatSubmit { () -> EventLoopFuture in - localClient.get(url: url).flatMap { firstResponse in - XCTAssertEqual(.ok, firstResponse.status) - return localClient.get(url: url) // <== interesting bit here - } - }.wait().status) + XCTAssertEqual( + .ok, + try el.flatSubmit { () -> EventLoopFuture in + localClient.get(url: url).flatMap { firstResponse in + XCTAssertEqual(.ok, firstResponse.status) + return localClient.get(url: url) // <== interesting bit here + } + }.wait().status + ) } func testMakeSecondRequestWhilstFirstIsOngoing() { @@ -1553,11 +2117,11 @@ class HTTPClientTests: XCTestCase { let url = "http://127.0.0.1:\(web.serverPort)" let firstRequest = client.get(url: url) - XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head + XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head // Now, the first request is ongoing but not complete, let's start a second one let secondRequest = client.get(url: url) - XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end + XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end XCTAssertNoThrow(try web.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .ok)))) XCTAssertNoThrow(try web.writeOutbound(.end(nil))) @@ -1565,8 +2129,8 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(.ok, try firstRequest.wait().status) // Okay, first request done successfully, let's do the second one too. - XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head - XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end + XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head + XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end XCTAssertNoThrow(try web.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .created)))) XCTAssertNoThrow(try web.writeOutbound(.end(nil))) @@ -1577,15 +2141,19 @@ class HTTPClientTests: XCTestCase { // This tests just connecting to a URL where the whole URL is the UNIX domain socket path like // unix:///this/is/my/socket.sock // We don't really have a path component, so we'll have to use "/" - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localHTTPBin.shutdown()) + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + let target = "unix://\(path)" + XCTAssertEqual( + ["Yes"[...]], + try self.defaultClient.get(url: target).wait().headers[canonicalForm: "X-Is-This-Slash"] + ) } - let target = "unix://\(path)" - XCTAssertEqual(["Yes"[...]], - try self.defaultClient.get(url: target).wait().headers[canonicalForm: "X-Is-This-Slash"]) - }) + ) } func testUDSSocketAndPath() { @@ -1593,56 +2161,73 @@ class HTTPClientTests: XCTestCase { // // 1. a "base path" which is the path to the UNIX domain socket // 2. an actual path which is the normal path in a regular URL like https://example.com/this/is/the/path - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } - guard let target = URL(string: "/echo-uri", relativeTo: URL(string: "unix://\(path)")), - let request = try? Request(url: target) else { - XCTFail("couldn't build URL for request") - return + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(string: "/echo-uri", relativeTo: URL(string: "unix://\(path)")), + let request = try? Request(url: target) + else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertEqual( + ["/echo-uri"[...]], + try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"] + ) } - XCTAssertEqual(["/echo-uri"[...]], - try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"]) - }) + ) } func testHTTPPlusUNIX() { // Here, we're testing a URL where the UNIX domain socket is encoded as the host name - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } - guard let target = URL(httpURLWithSocketPath: path, uri: "/echo-uri"), - let request = try? Request(url: target) else { - XCTFail("couldn't build URL for request") - return + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(httpURLWithSocketPath: path, uri: "/echo-uri"), + let request = try? Request(url: target) + else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertEqual( + ["/echo-uri"[...]], + try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"] + ) } - XCTAssertEqual(["/echo-uri"[...]], - try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"]) - }) + ) } func testHTTPSPlusUNIX() { // Here, we're testing a URL where the UNIX domain socket is encoded as the host name - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } - guard let target = URL(httpsURLWithSocketPath: path, uri: "/echo-uri"), - let request = try? Request(url: target) else { - XCTFail("couldn't build URL for request") - return + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(httpsURLWithSocketPath: path, uri: "/echo-uri"), + let request = try? Request(url: target) + else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertEqual( + ["/echo-uri"[...]], + try localClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"] + ) } - XCTAssertEqual(["/echo-uri"[...]], - try localClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"]) - }) + ) } func testUseExistingConnectionOnDifferentEL() throws { @@ -1656,33 +2241,40 @@ class HTTPClientTests: XCTestCase { let eventLoops = (1...threadCount).map { _ in elg.next() } let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get") - let closingRequest = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", headers: ["Connection": "close"]) + let closingRequest = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + headers: ["Connection": "close"] + ) for (index, el) in eventLoops.enumerated() { if index.isMultiple(of: 2) { - XCTAssertNoThrow(try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait()) + XCTAssertNoThrow( + try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait() + ) } else { - XCTAssertNoThrow(try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait()) + XCTAssertNoThrow( + try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait() + ) XCTAssertNoThrow(try localClient.execute(request: closingRequest, eventLoop: .indifferent).wait()) } } } func testWeRecoverFromServerThatClosesTheConnectionOnUs() { - final class ServerThatAcceptsThenRejects: ChannelInboundHandler { + final class ServerThatAcceptsThenRejects: ChannelInboundHandler, Sendable { typealias InboundIn = HTTPServerRequestPart typealias OutboundOut = HTTPServerResponsePart - let requestNumber: NIOAtomic - let connectionNumber: NIOAtomic + let requestNumber: ManagedAtomic + let connectionNumber: ManagedAtomic - init(requestNumber: NIOAtomic, connectionNumber: NIOAtomic) { + init(requestNumber: ManagedAtomic, connectionNumber: ManagedAtomic) { self.requestNumber = requestNumber self.connectionNumber = connectionNumber } func channelActive(context: ChannelHandlerContext) { - _ = self.connectionNumber.add(1) + _ = self.connectionNumber.loadThenWrappingIncrement(ordering: .relaxed) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -1692,11 +2284,13 @@ class HTTPClientTests: XCTestCase { case .head, .body: () case .end: - let last = self.requestNumber.add(1) + let last = self.requestNumber.loadThenWrappingIncrement(ordering: .relaxed) switch last { case 0, 2: - context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), - promise: nil) + context.write( + self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) case 1: context.close(promise: nil) @@ -1707,22 +2301,26 @@ class HTTPClientTests: XCTestCase { } } - let requestNumber = NIOAtomic.makeAtomic(value: 0) - let connectionNumber = NIOAtomic.makeAtomic(value: 0) - let sharedStateServerHandler = ServerThatAcceptsThenRejects(requestNumber: requestNumber, - connectionNumber: connectionNumber) + let requestNumber = ManagedAtomic(0) + let connectionNumber = ManagedAtomic(0) + let sharedStateServerHandler = ServerThatAcceptsThenRejects( + requestNumber: requestNumber, + connectionNumber: connectionNumber + ) var maybeServer: Channel? - XCTAssertNoThrow(maybeServer = try ServerBootstrap(group: self.serverGroup) - .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline().flatMap { - // We're deliberately adding a handler which is shared between multiple channels. This is normally - // very verboten but this handler is specially crafted to tolerate this. - channel.pipeline.addHandler(sharedStateServerHandler) + XCTAssertNoThrow( + maybeServer = try ServerBootstrap(group: self.serverGroup) + .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline().flatMap { + // We're deliberately adding a handler which is shared between multiple channels. This is normally + // very verboten but this handler is specially crafted to tolerate this. + channel.pipeline.addHandler(sharedStateServerHandler) + } } - } - .bind(host: "127.0.0.1", port: 0) - .wait()) + .bind(host: "127.0.0.1", port: 0) + .wait() + ) guard let server = maybeServer else { XCTFail("couldn't create server") return @@ -1737,46 +2335,57 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try client.syncShutdown()) } - XCTAssertEqual(0, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(0, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(0, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(0, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) XCTAssertEqual(.ok, try client.get(url: url).wait().status) - XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(1, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(1, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) XCTAssertThrowsError(try client.get(url: url).wait().status) { error in XCTAssertEqual(.remoteConnectionClosed, error as? HTTPClientError) } - XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(2, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(2, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) XCTAssertEqual(.ok, try client.get(url: url).wait().status) - XCTAssertEqual(2, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(3, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(2, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(3, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) } func testPoolClosesIdleConnections() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(connectionPool: .init(idleTimeout: .milliseconds(100)))) + let configuration = HTTPClient.Configuration( + certificateVerification: .none, + maximumAllowedIdleTimeInConnectionPool: .milliseconds(100) + ) + + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: configuration + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + + // Make sure that the idle timeout of the connection pool is properly propagated + // to the connection pool itself, when using both inits. + XCTAssertEqual(configuration.connectionPool.idleTimeout, .milliseconds(100)) + XCTAssertEqual( + configuration.connectionPool.idleTimeout, + HTTPClient.Configuration( + certificateVerification: .none, + connectionPool: .milliseconds(100), + backgroundActivityLogger: nil + ).connectionPool.idleTimeout + ) + XCTAssertNoThrow(try localClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()) Thread.sleep(forTimeInterval: 0.2) XCTAssertEqual(self.defaultHTTPBin.activeConnections, 0) } - func testRacePoolIdleConnectionsAndGet() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(connectionPool: .init(idleTimeout: .milliseconds(10)))) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - } - for _ in 1...500 { - XCTAssertNoThrow(try localClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()) - Thread.sleep(forTimeInterval: 0.01 + .random(in: -0.05...0.05)) - } - } - func testAvoidLeakingTLSHandshakeCompletionPromise() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(timeout: .init(connect: .milliseconds(100)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(100))) + ) let localHTTPBin = HTTPBin() let port = localHTTPBin.port XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -1786,7 +2395,12 @@ class HTTPClientTests: XCTestCase { XCTAssertThrowsError(try localClient.get(url: "http://localhost:\(port)").wait()) { error in if isTestingNIOTS() { - XCTAssertEqual(error as? HTTPClientError, .connectTimeout) + #if canImport(Network) + // We can't be more specific than this. + XCTAssertTrue(error is HTTPClient.NWTLSError || error is HTTPClient.NWPOSIXError) + #else + XCTFail("Impossible condition") + #endif } else { XCTAssert(error is NIOConnectionError, "Unexpected error: \(error)") } @@ -1818,9 +2432,14 @@ class HTTPClientTests: XCTestCase { } func testValidationErrorsAreSurfaced() throws { - let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .TRACE, body: .stream { _ in - self.defaultClient.eventLoopGroup.next().makeSucceededFuture(()) - }) + let defaultClient = self.defaultClient! + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .TRACE, + body: .stream { _ in + defaultClient.eventLoopGroup.next().makeSucceededFuture(()) + } + ) let runningRequest = self.defaultClient.execute(request: request) XCTAssertThrowsError(try runningRequest.wait()) { error in XCTAssertEqual(HTTPClientError.traceRequestWithBody, error as? HTTPClientError) @@ -1838,9 +2457,11 @@ class HTTPClientTests: XCTestCase { private var bodyPartsSeenSoFar = 0 private var atEnd = false - init(headPromise: EventLoopPromise, - bodyPromises: [EventLoopPromise], - endPromise: EventLoopPromise) { + init( + headPromise: EventLoopPromise, + bodyPromises: [EventLoopPromise], + endPromise: EventLoopPromise + ) { self.headPromise = headPromise self.bodyPromises = bodyPromises self.endPromise = endPromise @@ -1856,8 +2477,10 @@ class HTTPClientTests: XCTestCase { self.bodyPartsSeenSoFar += 1 self.bodyPromises.dropFirst(myNumber).first?.succeed(bytes) ?? XCTFail("ouch, too many chunks") case .end: - context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), - promise: nil) + context.write( + self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: self.endPromise) self.atEnd = true } @@ -1870,8 +2493,8 @@ class HTTPClientTests: XCTestCase { struct NotFulfilledError: Error {} self.headPromise.fail(NotFulfilledError()) - self.bodyPromises.forEach { - $0.fail(NotFulfilledError()) + for promise in self.bodyPromises { + promise.fail(NotFulfilledError()) } self.endPromise.fail(NotFulfilledError()) } @@ -1892,12 +2515,16 @@ class HTTPClientTests: XCTestCase { let streamWriterPromise = group.next().makePromise(of: HTTPClient.Body.StreamWriter.self) func makeServer() -> Channel? { - return try? ServerBootstrap(group: group) + try? ServerBootstrap(group: group) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline().flatMap { - channel.pipeline.addHandler(HTTPServer(headPromise: headPromise, - bodyPromises: bodyPromises, - endPromise: endPromise)) + channel.pipeline.configureHTTPServerPipeline().flatMapThrowing { + try channel.pipeline.syncOperations.addHandler( + HTTPServer( + headPromise: headPromise, + bodyPromises: bodyPromises, + endPromise: endPromise + ) + ) } } .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) @@ -1910,13 +2537,15 @@ class HTTPClientTests: XCTestCase { return nil } - return try? HTTPClient.Request(url: "http://\(localAddress.ipAddress!):\(localAddress.port!)", - method: .POST, - headers: ["transfer-encoding": "chunked"], - body: .stream { streamWriter in - streamWriterPromise.succeed(streamWriter) - return sentOffAllBodyPartsPromise.futureResult - }) + return try? HTTPClient.Request( + url: "http://\(localAddress.ipAddress!):\(localAddress.port!)", + method: .POST, + headers: ["transfer-encoding": "chunked"], + body: .stream { streamWriter in + streamWriterPromise.succeed(streamWriter) + return sentOffAllBodyPartsPromise.futureResult + } + ) } guard let server = makeServer(), let request = makeRequest(server: server) else { @@ -1948,35 +2577,46 @@ class HTTPClientTests: XCTestCase { } func testUploadStreamingCallinToleratedFromOtsideEL() throws { - let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .POST, body: .stream(length: 4) { writer in - let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) - // We have to toleare callins from any thread - DispatchQueue(label: "upload-streaming").async { - writer.write(.byteBuffer(ByteBuffer(string: "1234"))).whenComplete { _ in - promise.succeed(()) + let defaultClient = self.defaultClient! + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .POST, + body: .stream(contentLength: 4) { writer in + let promise = defaultClient.eventLoopGroup.next().makePromise(of: Void.self) + // We have to toleare callins from any thread + DispatchQueue(label: "upload-streaming").async { + writer.write(.byteBuffer(ByteBuffer(string: "1234"))).whenComplete { _ in + promise.succeed(()) + } } + return promise.futureResult } - return promise.futureResult - }) + ) XCTAssertNoThrow(try self.defaultClient.execute(request: request).wait()) } func testWeHandleUsSendingACloseHeaderCorrectly() { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["connection": "close"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: ["connection": "close"] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -1992,21 +2632,27 @@ class HTTPClientTests: XCTestCase { } func testWeHandleUsReceivingACloseHeaderCorrectly() { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["X-Send-Back-Header-Connection": "close"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: ["X-Send-Back-Header-Connection": "close"] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -2023,22 +2669,32 @@ class HTTPClientTests: XCTestCase { func testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly() { for closeHeader in [("connection", "close"), ("CoNneCTION", "ClOSe")] { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["X-Send-Back-Header-\(closeHeader.0)": - "foo,\(closeHeader.1),bar"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: [ + "X-Send-Back-Header-\(closeHeader.0)": + "foo,\(closeHeader.1),bar" + ] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } - guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + guard + let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } - guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + guard + let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -2055,22 +2711,32 @@ class HTTPClientTests: XCTestCase { func testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly() { for closeHeader in [("connection", "close"), ("CoNneCTION", "ClOSe")] { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["X-Send-Back-Header-\(closeHeader.0)": - "foo,\(closeHeader.1),bar"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: [ + "X-Send-Back-Header-\(closeHeader.0)": + "foo,\(closeHeader.1),bar" + ] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } - guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + guard + let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } - guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + guard + let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -2086,31 +2752,30 @@ class HTTPClientTests: XCTestCase { } func testLoggingCorrectlyAttachesRequestInformationEvenAfterDuringRedirect() { - let logStore = CollectEverythingLogHandler.LogStore() - - var logger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) - logger.logLevel = .trace + var (logStore, logger) = InMemoryLogHandler.makeLogger(logLevel: .trace) logger[metadataKey: "custom-request-id"] = "abcd" var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "http://localhost:\(self.defaultHTTPBin.port)/redirect/target", - method: .GET, - headers: [ - "X-Target-Redirect-URL": "/get", - ] - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost:\(self.defaultHTTPBin.port)/redirect/target", + method: .GET, + headers: [ + "X-Target-Redirect-URL": "/get" + ] + ) + ) guard let request = maybeRequest else { return } - XCTAssertNoThrow(try self.defaultClient.execute( - request: request, - eventLoop: .indifferent, - deadline: nil, - logger: logger - ).wait()) - let logs = logStore.allEntries + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request, + eventLoop: .indifferent, + deadline: nil, + logger: logger + ).wait() + ) + let logs = logStore.entries XCTAssertTrue(logs.allSatisfy { $0.metadata["custom-request-id"] == "abcd" }) @@ -2128,316 +2793,392 @@ class HTTPClientTests: XCTestCase { XCTAssertGreaterThan(secondRequestLogs.count, 0) XCTAssertTrue(secondRequestLogs.allSatisfy { $0.metadata["ahc-request-id"] == lastRequestID }) - logs.forEach { print($0) } + for log in logs { print(log) } } func testLoggingCorrectlyAttachesRequestInformation() { - let logStore = CollectEverythingLogHandler.LogStore() + let logStore = InMemoryLogHandler() - var loggerYolo001 = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var loggerYolo001 = Logger( + label: "\(#function)", + factory: { _ in + logStore + } + ) loggerYolo001.logLevel = .trace loggerYolo001[metadataKey: "yolo-request-id"] = "yolo-001" - var loggerACME002 = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var loggerACME002 = Logger( + label: "\(#function)", + factory: { _ in + logStore + } + ) loggerACME002.logLevel = .trace loggerACME002[metadataKey: "acme-request-id"] = "acme-002" guard let request1 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get"), - let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats"), - let request3 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "ok") else { + let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats"), + let request3 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "ok") + else { XCTFail("bad stuff, can't even make request structures") return } // === Request 1 (Yolo001) - XCTAssertNoThrow(try self.defaultClient.execute(request: request1, - eventLoop: .indifferent, - deadline: nil, - logger: loggerYolo001).wait()) - let logsAfterReq1 = logStore.allEntries - logStore.allEntries = [] + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request1, + eventLoop: .indifferent, + deadline: nil, + logger: loggerYolo001 + ).wait() + ) + let logsAfterReq1 = logStore.entries + logStore.clear() // === Request 2 (Yolo001) - XCTAssertNoThrow(try self.defaultClient.execute(request: request2, - eventLoop: .indifferent, - deadline: nil, - logger: loggerYolo001).wait()) - let logsAfterReq2 = logStore.allEntries - logStore.allEntries = [] + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request2, + eventLoop: .indifferent, + deadline: nil, + logger: loggerYolo001 + ).wait() + ) + let logsAfterReq2 = logStore.entries + logStore.clear() // === Request 3 (ACME002) - XCTAssertNoThrow(try self.defaultClient.execute(request: request3, - eventLoop: .indifferent, - deadline: nil, - logger: loggerACME002).wait()) - let logsAfterReq3 = logStore.allEntries - logStore.allEntries = [] + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request3, + eventLoop: .indifferent, + deadline: nil, + logger: loggerACME002 + ).wait() + ) + let logsAfterReq3 = logStore.entries + logStore.clear() // === Assertions XCTAssertGreaterThan(logsAfterReq1.count, 0) XCTAssertGreaterThan(logsAfterReq2.count, 0) XCTAssertGreaterThan(logsAfterReq3.count, 0) - XCTAssert(logsAfterReq1.allSatisfy { entry in - if let httpRequestMetadata = entry.metadata["ahc-request-id"], - let yoloRequestID = entry.metadata["yolo-request-id"] { - XCTAssertNil(entry.metadata["acme-request-id"]) - XCTAssertEqual("yolo-001", yoloRequestID) - XCTAssertNotNil(Int(httpRequestMetadata)) - return true - } else { - XCTFail("log message doesn't contain the right IDs: \(entry)") - return false - } - }) - XCTAssert(logsAfterReq1.contains { entry in - // Since a new connection must be created first we expect that the request is queued - // and log message describing this is emitted. - entry.message == "Request was queued (waiting for a connection to become available)" - && entry.level == .debug - }) - XCTAssert(logsAfterReq1.contains { entry in - // After the new connection was created we expect a log message that describes that the - // request was scheduled on a connection. The connection id must be set from here on. - entry.message == "Request was scheduled on connection" - && entry.level == .debug - && entry.metadata["ahc-connection-id"] != nil - }) - - XCTAssert(logsAfterReq2.allSatisfy { entry in - if let httpRequestMetadata = entry.metadata["ahc-request-id"], - let yoloRequestID = entry.metadata["yolo-request-id"] { - XCTAssertNil(entry.metadata["acme-request-id"]) - XCTAssertEqual("yolo-001", yoloRequestID) - XCTAssertNotNil(Int(httpRequestMetadata)) - return true - } else { - XCTFail("log message doesn't contain the right IDs: \(entry)") - return false - } - }) - XCTAssertFalse(logsAfterReq2.contains { entry in - entry.message == "Request was queued (waiting for a connection to become available)" - }) - XCTAssert(logsAfterReq2.contains { entry in - entry.message == "Request was scheduled on connection" - && entry.level == .debug - && entry.metadata["ahc-connection-id"] != nil - }) - - XCTAssert(logsAfterReq3.allSatisfy { entry in - if let httpRequestMetadata = entry.metadata["ahc-request-id"], - let acmeRequestID = entry.metadata["acme-request-id"] { - XCTAssertNil(entry.metadata["yolo-request-id"]) - XCTAssertEqual("acme-002", acmeRequestID) - XCTAssertNotNil(Int(httpRequestMetadata)) - return true - } else { - XCTFail("log message doesn't contain the right IDs: \(entry)") - return false + XCTAssert( + logsAfterReq1.allSatisfy { entry in + if let httpRequestMetadata = entry.metadata["ahc-request-id"], + let yoloRequestID = entry.metadata["yolo-request-id"] + { + XCTAssertNil(entry.metadata["acme-request-id"]) + XCTAssertEqual("yolo-001", yoloRequestID) + XCTAssertNotNil(Int("\(httpRequestMetadata)")) + return true + } else { + XCTFail("log message doesn't contain the right IDs: \(entry)") + return false + } + } + ) + XCTAssert( + logsAfterReq1.contains { entry in + // Since a new connection must be created first we expect that the request is queued + // and log message describing this is emitted. + entry.message == "Request was queued (waiting for a connection to become available)" + && entry.level == .debug + } + ) + XCTAssert( + logsAfterReq1.contains { entry in + // After the new connection was created we expect a log message that describes that the + // request was scheduled on a connection. The connection id must be set from here on. + entry.message == "Request was scheduled on connection" + && entry.level == .debug + && entry.metadata["ahc-connection-id"] != nil + } + ) + + XCTAssert( + logsAfterReq2.allSatisfy { entry in + if let httpRequestMetadata = entry.metadata["ahc-request-id"], + let yoloRequestID = entry.metadata["yolo-request-id"] + { + XCTAssertNil(entry.metadata["acme-request-id"]) + XCTAssertEqual("yolo-001", yoloRequestID) + XCTAssertNotNil(Int("\(httpRequestMetadata)")) + return true + } else { + XCTFail("log message doesn't contain the right IDs: \(entry)") + return false + } + } + ) + XCTAssertFalse( + logsAfterReq2.contains { entry in + entry.message == "Request was queued (waiting for a connection to become available)" + } + ) + XCTAssert( + logsAfterReq2.contains { entry in + entry.message == "Request was scheduled on connection" + && entry.level == .debug + && entry.metadata["ahc-connection-id"] != nil + } + ) + + XCTAssert( + logsAfterReq3.allSatisfy { entry in + if let httpRequestMetadata = entry.metadata["ahc-request-id"], + let acmeRequestID = entry.metadata["acme-request-id"] + { + XCTAssertNil(entry.metadata["yolo-request-id"]) + XCTAssertEqual("acme-002", acmeRequestID) + XCTAssertNotNil(Int("\(httpRequestMetadata)")) + return true + } else { + XCTFail("log message doesn't contain the right IDs: \(entry)") + return false + } + } + ) + XCTAssertFalse( + logsAfterReq3.contains { entry in + entry.message == "Request was queued (waiting for a connection to become available)" + } + ) + XCTAssert( + logsAfterReq3.contains { entry in + entry.message == "Request was scheduled on connection" + && entry.level == .debug + && entry.metadata["ahc-connection-id"] != nil } - }) - XCTAssertFalse(logsAfterReq3.contains { entry in - entry.message == "Request was queued (waiting for a connection to become available)" - }) - XCTAssert(logsAfterReq3.contains { entry in - entry.message == "Request was scheduled on connection" - && entry.level == .debug - && entry.metadata["ahc-connection-id"] != nil - }) + ) } func testNothingIsLoggedAtInfoOrHigher() { - let logStore = CollectEverythingLogHandler.LogStore() - - var logger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) - logger.logLevel = .info + let (logStore, logger) = InMemoryLogHandler.makeLogger(logLevel: .info) guard let request1 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get"), - let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats") else { + let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats") + else { XCTFail("bad stuff, can't even make request structures") return } // === Request 1 - XCTAssertNoThrow(try self.defaultClient.execute(request: request1, - eventLoop: .indifferent, - deadline: nil, - logger: logger).wait()) - XCTAssertEqual(0, logStore.allEntries.count) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request1, + eventLoop: .indifferent, + deadline: nil, + logger: logger + ).wait() + ) + XCTAssertEqual(0, logStore.entries.count) // === Request 2 - XCTAssertNoThrow(try self.defaultClient.execute(request: request2, - eventLoop: .indifferent, - deadline: nil, - logger: logger).wait()) - XCTAssertEqual(0, logStore.allEntries.count) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request2, + eventLoop: .indifferent, + deadline: nil, + logger: logger + ).wait() + ) + XCTAssertEqual(0, logStore.entries.count) // === Synthesized Request - XCTAssertNoThrow(try self.defaultClient.execute(.GET, - url: self.defaultHTTPBinURLPrefix + "get", - body: nil, - deadline: nil, - logger: logger).wait()) - XCTAssertEqual(0, logStore.allEntries.count) + XCTAssertNoThrow( + try self.defaultClient.execute( + .GET, + url: self.defaultHTTPBinURLPrefix + "get", + body: nil, + deadline: nil, + logger: logger + ).wait() + ) + XCTAssertEqual(0, logStore.entries.count) - XCTAssertEqual(0, self.backgroundLogStore.allEntries.filter { $0.level >= .info }.count) + XCTAssertEqual(0, self.backgroundLogStore.entries.filter { $0.level >= .info }.count) // === Synthesized Socket Path Request - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace - - let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let (backgroundLogStore, backgroundLogger) = InMemoryLogHandler.makeLogger(logLevel: .trace) + + let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertNoThrow(try localClient.execute(.GET, - socketPath: path, - urlPath: "get", - body: nil, - deadline: nil, - logger: logger).wait()) - XCTAssertEqual(0, logStore.allEntries.count) + XCTAssertNoThrow( + try localClient.execute( + .GET, + socketPath: path, + urlPath: "get", + body: nil, + deadline: nil, + logger: logger + ).wait() + ) + XCTAssertEqual(0, logStore.entries.count) - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .info }.count) - }) + XCTAssertEqual(0, backgroundLogStore.entries.filter { $0.level >= .info }.count) + } + ) // === Synthesized Secure Socket Path Request - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace - - let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let (backgroundLogStore, backgroundLogger) = InMemoryLogHandler.makeLogger(logLevel: .trace) + + let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertNoThrow(try localClient.execute(.GET, - secureSocketPath: path, - urlPath: "get", - body: nil, - deadline: nil, - logger: logger).wait()) - XCTAssertEqual(0, logStore.allEntries.count) + XCTAssertNoThrow( + try localClient.execute( + .GET, + secureSocketPath: path, + urlPath: "get", + body: nil, + deadline: nil, + logger: logger + ).wait() + ) + XCTAssertEqual(0, logStore.entries.count) - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .info }.count) - }) + XCTAssertEqual(0, backgroundLogStore.entries.filter { $0.level >= .info }.count) + } + ) } func testAllMethodsLog() { func checkExpectationsWithLogger(type: String, _ body: (Logger, String) throws -> T) throws -> T { - let logStore = CollectEverythingLogHandler.LogStore() - - var logger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) - logger.logLevel = .trace + var (logStore, logger) = InMemoryLogHandler.makeLogger(logLevel: .trace) logger[metadataKey: "req"] = "yo-\(type)" let url = "not-found/request/\(type))" let result = try body(logger, url) - XCTAssertGreaterThan(logStore.allEntries.count, 0) - logStore.allEntries.forEach { entry in + XCTAssertGreaterThan(logStore.entries.count, 0) + for entry in logStore.entries { XCTAssertEqual("yo-\(type)", entry.metadata["req"] ?? "n/a") - XCTAssertNotNil(Int(entry.metadata["ahc-request-id"] ?? "n/a")) + XCTAssertNotNil(Int(entry.metadata["ahc-request-id"]?.description ?? "n/a")) } return result } - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "GET") { logger, url in - try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "GET") { logger, url in + try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "PUT") { logger, url in - try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "PUT") { logger, url in + try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "POST") { logger, url in - try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "POST") { logger, url in + try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "DELETE") { logger, url in - try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "DELETE") { logger, url in + try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "PATCH") { logger, url in - try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "PATCH") { logger, url in + try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "CHECKOUT") { logger, url in - try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "CHECKOUT") { logger, url in + try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + url, logger: logger) + .wait() + }.status + ) // No background activity expected here. - XCTAssertEqual(0, self.backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) - - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace - - let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertEqual(0, self.backgroundLogStore.entries.filter { $0.level >= .debug }.count) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "GET") { logger, url in - try localClient.execute(socketPath: path, urlPath: url, logger: logger).wait() - }.status) + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let (backgroundLogStore, backgroundLogger) = InMemoryLogHandler.makeLogger(logLevel: .trace) - // No background activity expected here. - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) - }) + let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "GET") { logger, url in + try localClient.execute(socketPath: path, urlPath: url, logger: logger).wait() + }.status + ) - let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + // No background activity expected here. + XCTAssertEqual(0, backgroundLogStore.entries.filter { $0.level >= .debug }.count) } + ) + + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let (backgroundLogStore, backgroundLogger) = InMemoryLogHandler.makeLogger(logLevel: .trace) + + let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "GET") { logger, url in - try localClient.execute(secureSocketPath: path, urlPath: url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "GET") { logger, url in + try localClient.execute(secureSocketPath: path, urlPath: url, logger: logger).wait() + }.status + ) - // No background activity expected here. - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) - }) + // No background activity expected here. + XCTAssertEqual(0, backgroundLogStore.entries.filter { $0.level >= .debug }.count) + } + ) } func testClosingIdleConnectionsInPoolLogsInTheBackground() { @@ -2445,17 +3186,20 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try self.defaultClient.syncShutdown()) - XCTAssertGreaterThanOrEqual(self.backgroundLogStore.allEntries.count, 0) - XCTAssert(self.backgroundLogStore.allEntries.contains { entry in - entry.message == "Shutting down connection pool" - }) - XCTAssert(self.backgroundLogStore.allEntries.allSatisfy { entry in - entry.metadata["ahc-request-id"] == nil && - entry.metadata["ahc-request"] == nil && - entry.metadata["ahc-pool-key"] != nil - }) + XCTAssertGreaterThanOrEqual(self.backgroundLogStore.entries.count, 0) + XCTAssert( + self.backgroundLogStore.entries.contains { entry in + entry.message == "Shutting down connection pool" + } + ) + XCTAssert( + self.backgroundLogStore.entries.allSatisfy { entry in + entry.metadata["ahc-request-id"] == nil && entry.metadata["ahc-request"] == nil + && entry.metadata["ahc-pool-key"] != nil + } + ) - self.defaultClient = nil // so it doesn't get shut down again. + self.defaultClient = nil // so it doesn't get shut down again. } func testUploadStreamingNoLength() throws { @@ -2480,8 +3224,8 @@ class HTTPClientTests: XCTestCase { XCTFail("Unexpected part") } - XCTAssertNoThrow(try server.readInbound()) // .body - XCTAssertNoThrow(try server.readInbound()) // .end + XCTAssertNoThrow(try server.readInbound()) // .body + XCTAssertNoThrow(try server.readInbound()) // .end XCTAssertNoThrow(try server.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .ok)))) XCTAssertNoThrow(try server.writeOutbound(.end(nil))) @@ -2490,17 +3234,19 @@ class HTTPClientTests: XCTestCase { } func testConnectErrorPropagatedToDelegate() throws { - class TestDelegate: HTTPClientResponseDelegate { + final class TestDelegate: HTTPClientResponseDelegate { typealias Response = Void - var error: Error? + let error = NIOLockedValueBox(nil) func didFinishRequest(task: HTTPClient.Task) throws {} func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.error = error + self.error.withLockedValue { $0 = error } } } - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(timeout: .init(connect: .milliseconds(10)))) + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(10))) + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown()) @@ -2511,13 +3257,13 @@ class HTTPClientTests: XCTestCase { let delegate = TestDelegate() XCTAssertThrowsError(try httpClient.execute(request: request, delegate: delegate).wait()) { - XCTAssertEqual(.connectTimeout, $0 as? HTTPClientError) - XCTAssertEqual(.connectTimeout, delegate.error as? HTTPClientError) + XCTAssertEqualTypeAndValue($0, HTTPClientError.connectTimeout) + XCTAssertEqualTypeAndValue(delegate.error.withLockedValue { $0 }, HTTPClientError.connectTimeout) } } func testDelegateCallinsTolerateRandomEL() throws { - class TestDelegate: HTTPClientResponseDelegate { + final class TestDelegate: HTTPClientResponseDelegate { typealias Response = Void let eventLoop: EventLoop @@ -2526,11 +3272,11 @@ class HTTPClientTests: XCTestCase { } func didReceiveHead(task: HTTPClient.Task, _: HTTPResponseHead) -> EventLoopFuture { - return self.eventLoop.makeSucceededFuture(()) + self.eventLoop.makeSucceededFuture(()) } func didReceiveBodyPart(task: HTTPClient.Task, _: ByteBuffer) -> EventLoopFuture { - return self.eventLoop.makeSucceededFuture(()) + self.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws {} @@ -2553,8 +3299,8 @@ class HTTPClientTests: XCTestCase { let request = try HTTPClient.Request(url: "http://localhost:\(httpServer.serverPort)/") let future = httpClient.execute(request: request, delegate: delegate) - XCTAssertNoThrow(try httpServer.readInbound()) // .head - XCTAssertNoThrow(try httpServer.readInbound()) // .end + XCTAssertNoThrow(try httpServer.readInbound()) // .head + XCTAssertNoThrow(try httpServer.readInbound()) // .end XCTAssertNoThrow(try httpServer.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .ok)))) XCTAssertNoThrow(try httpServer.writeOutbound(.body(.byteBuffer(ByteBuffer(string: "1234"))))) @@ -2563,18 +3309,58 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try future.wait()) } + func testDelegateGetsErrorsFromCreatingRequestBag() throws { + // We want to test that we propagate errors to the delegate from failures to construct the + // request bag. Those errors only come from invalid headers. + final class TestDelegate: HTTPClientResponseDelegate, Sendable { + typealias Response = Void + let error: NIOLockedValueBox = .init(nil) + func didFinishRequest(task: HTTPClient.Task) throws {} + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + self.error.withLockedValue { $0 = error } + } + } + + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup) + ) + + defer { + XCTAssertNoThrow(try httpClient.syncShutdown()) + } + + // 198.51.100.254 is reserved for documentation only + var request = try HTTPClient.Request(url: "http://198.51.100.254:65535/get") + request.headers.replaceOrAdd(name: "Not-ASCII", value: "not-fine\n") + let delegate = TestDelegate() + + XCTAssertThrowsError(try httpClient.execute(request: request, delegate: delegate).wait()) { + XCTAssertEqualTypeAndValue($0, HTTPClientError.invalidHeaderFieldValues(["not-fine\n"])) + XCTAssertEqualTypeAndValue( + delegate.error.withLockedValue { $0 }, + HTTPClientError.invalidHeaderFieldValues(["not-fine\n"]) + ) + } + } + func testContentLengthTooLongFails() throws { let url = self.defaultHTTPBinURLPrefix + "post" + let defaultClient = self.defaultClient! XCTAssertThrowsError( - try self.defaultClient.execute(request: - Request(url: url, - body: .stream(length: 10) { streamWriter in - let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) + try self.defaultClient.execute( + request: + Request( + url: url, + body: .stream(contentLength: 10) { streamWriter in + let promise = defaultClient.eventLoopGroup.next().makePromise(of: Void.self) DispatchQueue(label: "content-length-test").async { streamWriter.write(.byteBuffer(ByteBuffer(string: "1"))).cascade(to: promise) } return promise.futureResult - })).wait()) { error in + } + ) + ).wait() + ) { error in XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch) } // Quickly try another request and check that it works. @@ -2596,11 +3382,16 @@ class HTTPClientTests: XCTestCase { let url = self.defaultHTTPBinURLPrefix + "post" let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n" XCTAssertThrowsError( - try self.defaultClient.execute(request: - Request(url: url, - body: .stream(length: 1) { streamWriter in + try self.defaultClient.execute( + request: + Request( + url: url, + body: .stream(contentLength: 1) { streamWriter in streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) - })).wait()) { error in + } + ) + ).wait() + ) { error in XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch) } // Quickly try another request and check that it works. If we by accident wrote some extra bytes into the @@ -2621,9 +3412,9 @@ class HTTPClientTests: XCTestCase { func testBodyUploadAfterEndFails() { let url = self.defaultHTTPBinURLPrefix + "post" - func uploader(_ streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture { + let uploader = { @Sendable (_ streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture in let done = streamWriter.write(.byteBuffer(ByteBuffer(string: "X"))) - done.recover { error -> Void in + done.recover { error in XCTFail("unexpected error \(error)") }.whenSuccess { // This is executed when we have already sent the end of the request. @@ -2642,7 +3433,7 @@ class HTTPClientTests: XCTestCase { } var request: HTTPClient.Request? - XCTAssertNoThrow(request = try Request(url: url, body: .stream(length: 1, uploader))) + XCTAssertNoThrow(request = try Request(url: url, body: .stream(contentLength: 1, uploader))) XCTAssertThrowsError(try self.defaultClient.execute(request: XCTUnwrap(request)).wait()) { XCTAssertEqual($0 as? HTTPClientError, .writeAfterRequestSent) } @@ -2652,59 +3443,12 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()) } - func testNoBytesSentOverBodyLimit() throws { - let server = NIOHTTP1TestServer(group: self.serverGroup) - defer { - XCTAssertNoThrow(try server.stop()) - } - - let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n" - - let request = try Request( - url: "http://localhost:\(server.serverPort)", - body: .stream(length: 1) { streamWriter in - streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) - } - ) - - let future = self.defaultClient.execute(request: request) - - // Okay, what happens here needs an explanation: - // - // In the request state machine, we should start the request, which will lead to an - // invocation of `context.write(HTTPRequestHead)`. Since we will receive a streamed request - // body a `context.flush()` will be issued. Further the request stream will be started. - // Since the request stream immediately produces to much data, the request will be failed - // and the connection will be closed. - // - // Even though a flush was issued after the request head, there is no guarantee that the - // request head was written to the network. For this reason we must accept not receiving a - // request and receiving a request head. - - do { - _ = try server.receiveHead() - - // A request head was sent. We expect the request now to fail with a parsing error, - // since the client ended the connection to early (from the server's point of view.) - XCTAssertThrowsError(try server.readInbound()) { - XCTAssertEqual($0 as? HTTPParserError, HTTPParserError.invalidEOFState) - } - } catch { - // TBD: We sadly can't verify the error type, since it is private in `NIOTestUtils`: - // NIOTestUtils.BlockingQueue.TimeoutError - } - - // request must always be failed with this error - XCTAssertThrowsError(try future.wait()) { - XCTAssertEqual($0 as? HTTPClientError, .bodyLengthMismatch) - } - } - func testDoubleError() throws { // This is needed to that connection pool will not get into closed state when we release // second connection. _ = self.defaultClient.get(url: "http://localhost:\(self.defaultHTTPBin.port)/events/10/1") + let clientGroup = self.clientGroup! var request = try HTTPClient.Request(url: "http://localhost:\(self.defaultHTTPBin.port)/wait", method: .POST) request.body = .stream { writer in // Start writing chunks so tha we will try to write after read timeout is thrown @@ -2712,8 +3456,8 @@ class HTTPClientTests: XCTestCase { _ = writer.write(.byteBuffer(ByteBuffer(string: "1234"))) } - let promise = self.clientGroup.next().makePromise(of: Void.self) - self.clientGroup.next().scheduleTask(in: .milliseconds(3)) { + let promise = clientGroup.next().makePromise(of: Void.self) + clientGroup.next().scheduleTask(in: .milliseconds(3)) { writer.write(.byteBuffer(ByteBuffer(string: "1234"))).cascade(to: promise) } @@ -2722,11 +3466,13 @@ class HTTPClientTests: XCTestCase { // We specify a deadline of 2 ms co that request will be timed out before all chunks are writtent, // we need to verify that second error on write after timeout does not lead to double-release. - XCTAssertThrowsError(try self.defaultClient.execute(request: request, deadline: .now() + .milliseconds(2)).wait()) + XCTAssertThrowsError( + try self.defaultClient.execute(request: request, deadline: .now() + .milliseconds(2)).wait() + ) } func testSSLHandshakeErrorPropagation() throws { - class CloseHandler: ChannelInboundHandler { + final class CloseHandler: ChannelInboundHandler, Sendable { typealias InboundIn = Any func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -2749,6 +3495,8 @@ class HTTPClientTests: XCTestCase { if isTestingNIOTS() { // If we are using Network.framework, we set the connect timeout down very low here // because on NIOTS a failing TLS handshake manifests as a connect timeout. + // Note that we do this here to prove that we correctly manifest the underlying error: + // DO NOT CHANGE THIS TO DISABLE WAITING FOR CONNECTIVITY. timeout.connect = .milliseconds(100) } @@ -2763,7 +3511,12 @@ class HTTPClientTests: XCTestCase { XCTAssertThrowsError(try task.wait()) { error in if isTestingNIOTS() { - XCTAssertEqual(error as? HTTPClientError, .connectTimeout) + #if canImport(Network) + // We can't be more specific than this. + XCTAssertTrue(error is HTTPClient.NWTLSError) + #else + XCTFail("Impossible condition") + #endif } else { switch error as? NIOSSLError { case .some(.handshakeFailed(.sslError(_))): break @@ -2776,11 +3529,11 @@ class HTTPClientTests: XCTestCase { func testSSLHandshakeErrorPropagationDelayedClose() throws { // This is as the test above, but the close handler delays its close action by a few hundred ms. // This will tend to catch the pipeline at different weird stages, and flush out different bugs. - class CloseHandler: ChannelInboundHandler { + final class CloseHandler: ChannelInboundHandler, Sendable { typealias InboundIn = Any func channelRead(context: ChannelHandlerContext, data: NIOAny) { - context.eventLoop.scheduleTask(in: .milliseconds(100)) { + context.eventLoop.assumeIsolated().scheduleTask(in: .milliseconds(100)) { context.close(promise: nil) } } @@ -2815,7 +3568,12 @@ class HTTPClientTests: XCTestCase { XCTAssertThrowsError(try task.wait()) { error in if isTestingNIOTS() { - XCTAssertEqual(error as? HTTPClientError, .connectTimeout) + #if canImport(Network) + // We can't be more specific than this. + XCTAssertTrue(error is HTTPClient.NWTLSError) + #else + XCTFail("Impossible condition") + #endif } else { switch error as? NIOSSLError { case .some(.handshakeFailed(.sslError(_))): break @@ -2832,8 +3590,8 @@ class HTTPClientTests: XCTestCase { let server = try ServerBootstrap(group: self.serverGroup) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline().flatMap { - channel.pipeline.addHandler(CloseWithoutClosingServerHandler(group.leave)) + channel.pipeline.configureHTTPServerPipeline().flatMapThrowing { + try channel.pipeline.syncOperations.addHandler(CloseWithoutClosingServerHandler(group.leave)) } } .bind(host: "localhost", port: 0) @@ -2872,7 +3630,7 @@ class HTTPClientTests: XCTestCase { let body: HTTPClient.Body = .stream { writer in let finalPromise = writeEL.makePromise(of: Void.self) - func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { // always invoke from the wrong el to test thread safety writeEL.preconditionInEventLoop() @@ -2914,185 +3672,428 @@ class HTTPClientTests: XCTestCase { XCTAssertNil(try delegate.next().wait()) } - func testSynchronousHandshakeErrorReporting() throws { - // This only affects cases where we use NIOSSL. - guard !isTestingNIOTS() else { return } + func testResponseAccumulatorMaxBodySizeLimitExceedingWithContentLength() throws { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } - // We use a specially crafted client that has no cipher suites to offer. To do this we ask - // only for cipher suites incompatible with our TLS version. - var tlsConfig = TLSConfiguration.makeClientConfiguration() - tlsConfig.minimumTLSVersion = .tlsv13 - tlsConfig.maximumTLSVersion = .tlsv12 - tlsConfig.certificateVerification = .none - let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(tlsConfiguration: tlsConfig)) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } + let body = ByteBuffer(bytes: 0..<11) - XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/").wait()) { error in - guard let clientError = error as? NIOSSLError, case NIOSSLError.handshakeFailed = clientError else { - XCTFail("Unexpected error: \(error)") - return - } + var request = try Request(url: httpBin.baseURL) + request.body = .byteBuffer(body) + XCTAssertThrowsError( + try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() + ) { error in + XCTAssertTrue(error is ResponseAccumulator.ResponseTooBigError, "unexpected error \(error)") } } - func testFileDownloadChunked() throws { - var request = try Request(url: self.defaultHTTPBinURLPrefix + "chunked") - request.headers.add(name: "Accept", value: "text/event-stream") - - let progress = - try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in - let delegate = try FileDownloadDelegate(path: path) - - let progress = try self.defaultClient.execute( - request: request, - delegate: delegate - ) - .wait() + func testResponseAccumulatorMaxBodySizeLimitNotExceedingWithContentLength() throws { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } - try XCTAssertEqual(50, TemporaryFileHelpers.fileSize(path: path)) + let body = ByteBuffer(bytes: 0..<10) - return progress - } + var request = try Request(url: httpBin.baseURL) + request.body = .byteBuffer(body) + let response = try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() - XCTAssertEqual(nil, progress.totalBytes) - XCTAssertEqual(50, progress.receivedBytes) + XCTAssertEqual(response.body, body) } - func testCloseWhileBackpressureIsExertedIsFine() throws { - let request = try Request(url: self.defaultHTTPBinURLPrefix + "close-on-response") - let delegate = DelayOnHeadDelegate(eventLoop: self.clientGroup.next()) { _, promise in - promise.futureResult.eventLoop.scheduleTask(in: .milliseconds(50)) { - promise.succeed(()) - } - } + func testResponseAccumulatorMaxBodySizeLimitExceedingWithContentLengthButMethodIsHead() throws { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTPEchoHeaders() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } - let resultFuture = self.defaultClient.execute(request: request, delegate: delegate) + let body = ByteBuffer(bytes: 0..<11) - // The full response must be correctly delivered. - var data = try resultFuture.wait() - guard let info = try data.readJSONDecodable(RequestInfo.self, length: data.readableBytes) else { - XCTFail("Could not parse response") - return - } - XCTAssertEqual(info.data, "some body content") - } + var request = try Request(url: httpBin.baseURL, method: .HEAD) + request.body = .byteBuffer(body) + let response = try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() - func testErrorAfterCloseWhileBackpressureExerted() throws { - enum ExpectedError: Error { - case expected - } + XCTAssertEqual(response.body ?? ByteBuffer(), ByteBuffer()) + } - let request = try Request(url: self.defaultHTTPBinURLPrefix + "close-on-response") - let delegate = DelayOnHeadDelegate(eventLoop: self.clientGroup.next()) { _, backpressurePromise in - backpressurePromise.fail(ExpectedError.expected) - } + func testResponseAccumulatorMaxBodySizeLimitExceedingWithTransferEncodingChuncked() throws { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } - let resultFuture = self.defaultClient.execute(request: request, delegate: delegate) + let body = ByteBuffer(bytes: 0..<11) - // The task must be failed. - XCTAssertThrowsError(try resultFuture.wait()) { error in - XCTAssertEqual(error as? ExpectedError, .expected) + var request = try Request(url: httpBin.baseURL) + request.body = .stream { writer in + writer.write(.byteBuffer(body)) + } + XCTAssertThrowsError( + try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() + ) { error in + XCTAssertTrue(error is ResponseAccumulator.ResponseTooBigError, "unexpected error \(error)") } } - func testRequestSpecificTLS() throws { - let configuration = HTTPClient.Configuration(tlsConfiguration: nil, - timeout: .init(), - decompression: .disabled) - let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: configuration) - let decoder = JSONDecoder() + func testResponseAccumulatorMaxBodySizeLimitNotExceedingWithTransferEncodingChuncked() throws { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } + let body = ByteBuffer(bytes: 0..<10) - // First two requests use identical TLS configurations. - var tlsConfig = TLSConfiguration.makeClientConfiguration() - tlsConfig.certificateVerification = .none - let firstRequest = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, tlsConfiguration: tlsConfig) - let firstResponse = try localClient.execute(request: firstRequest).wait() - guard let firstBody = firstResponse.body else { - XCTFail("No request body found") - return + var request = try Request(url: httpBin.baseURL) + request.body = .stream { writer in + writer.write(.byteBuffer(body)) } - let firstConnectionNumber = try decoder.decode(RequestInfo.self, from: firstBody).connectionNumber + let response = try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() - let secondRequest = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, tlsConfiguration: tlsConfig) - let secondResponse = try localClient.execute(request: secondRequest).wait() - guard let secondBody = secondResponse.body else { - XCTFail("No request body found") - return - } - let secondConnectionNumber = try decoder.decode(RequestInfo.self, from: secondBody).connectionNumber + XCTAssertEqual(response.body, body) + } - // Uses a differrent TLS config. - var tlsConfig2 = TLSConfiguration.makeClientConfiguration() - tlsConfig2.certificateVerification = .none - tlsConfig2.maximumTLSVersion = .tlsv1 - let thirdRequest = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, tlsConfiguration: tlsConfig2) - let thirdResponse = try localClient.execute(request: thirdRequest).wait() - guard let thirdBody = thirdResponse.body else { - XCTFail("No request body found") - return + // In this test, we test that a request can continue to stream its body after the response head and end + // was received where the end is a 200. + func testBiDirectionalStreamingEarly200() { + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in + HTTP200DelayedHandler(bodyPartsBeforeResponse: 1) } - let thirdConnectionNumber = try decoder.decode(RequestInfo.self, from: thirdBody).connectionNumber + defer { XCTAssertNoThrow(try httpBin.shutdown()) } - XCTAssertEqual(firstResponse.status, .ok) - XCTAssertEqual(secondResponse.status, .ok) - XCTAssertEqual(thirdResponse.status, .ok) - XCTAssertEqual(firstConnectionNumber, secondConnectionNumber, "Identical TLS configurations did not use the same connection") - XCTAssertNotEqual(thirdConnectionNumber, firstConnectionNumber, "Different TLS configurations did not use different connections.") - } + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let writeEL = eventLoopGroup.next() + let delegateEL = eventLoopGroup.next() - func testConnectionPoolSizeConfigValueIsRespected() { - let numberOfRequestsPerThread = 1000 - let numberOfParallelWorkers = 16 - let poolSize = 12 + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } - let httpBin = HTTPBin() - defer { XCTAssertNoThrow(try httpBin.shutdown()) } + let delegate = ResponseStreamDelegate(eventLoop: delegateEL) - let group = MultiThreadedEventLoopGroup(numberOfThreads: 4) - defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } + let body: HTTPClient.Body = .stream { writer in + let finalPromise = writeEL.makePromise(of: Void.self) - let configuration = HTTPClient.Configuration( - connectionPool: .init( - idleTimeout: .seconds(30), - concurrentHTTP1ConnectionsPerHostSoftLimit: poolSize - ) - ) - let client = HTTPClient(eventLoopGroupProvider: .shared(group), configuration: configuration) - defer { XCTAssertNoThrow(try client.syncShutdown()) } + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + // always invoke from the wrong el to test thread safety + writeEL.preconditionInEventLoop() - let g = DispatchGroup() - for workerID in 0..= 30 { + return finalPromise.succeed(()) } - for _ in 0..= 30 { + return finalPromise.succeed(()) + } + + let sent = ByteBuffer(integer: index) + writer.write(.byteBuffer(sent)).whenComplete { result in + switch result { + case .success: + writeEL.execute { + writeLoop(writer, index: index + 1) + } + + case .failure(let error): + finalPromise.fail(error) + } + } + } + + writeEL.execute { + writeLoop(writer, index: 0) + } + + return finalPromise.futureResult + } + + let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)", body: body) + let future = httpClient.execute(request: request) + XCTAssertNoThrow(try future.wait()) + + // Try another request + let future2 = httpClient.execute(request: request) + XCTAssertNoThrow(try future2.wait()) + } + + // This test validates that we correctly close the connection after our body completes when we've streamed a + // body and received the 2XX response _before_ we finished our stream. + func testCloseConnectionAfterEarly2XXWhenStreaming() { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let onClosePromise = eventLoopGroup.next().makePromise(of: Void.self) + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in + ExpectClosureServerHandler(onClosePromise: onClosePromise) + } + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + let writeEL = eventLoopGroup.next() + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } + + let body: HTTPClient.Body = .stream { writer in + let finalPromise = writeEL.makePromise(of: Void.self) + + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + // always invoke from the wrong el to test thread safety + writeEL.preconditionInEventLoop() + + if index >= 30 { + return finalPromise.succeed(()) + } + + let sent = ByteBuffer(integer: index) + writer.write(.byteBuffer(sent)).whenComplete { result in + switch result { + case .success: + writeEL.execute { + writeLoop(writer, index: index + 1) + } + + case .failure(let error): + finalPromise.fail(error) + } + } + } + + writeEL.execute { + writeLoop(writer, index: 0) + } + + return finalPromise.futureResult + } + + let headers = HTTPHeaders([("Connection", "close")]) + let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)", headers: headers, body: body) + let future = httpClient.execute(request: request) + XCTAssertNoThrow(try future.wait()) + XCTAssertNoThrow(try onClosePromise.futureResult.wait()) + } + + func testSynchronousHandshakeErrorReporting() throws { + // This only affects cases where we use NIOSSL. + guard !isTestingNIOTS() else { return } + + // We use a specially crafted client that has no cipher suites to offer. To do this we ask + // only for cipher suites incompatible with our TLS version. + var tlsConfig = TLSConfiguration.makeClientConfiguration() + tlsConfig.minimumTLSVersion = .tlsv13 + tlsConfig.maximumTLSVersion = .tlsv12 + tlsConfig.certificateVerification = .none + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(tlsConfiguration: tlsConfig) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + + XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/").wait()) { error in + guard let clientError = error as? NIOSSLError, case NIOSSLError.handshakeFailed = clientError else { + XCTFail("Unexpected error: \(error)") + return + } + } + } + + func testFileDownloadChunked() throws { + var request = try Request(url: self.defaultHTTPBinURLPrefix + "chunked") + request.headers.add(name: "Accept", value: "text/event-stream") + + let response = + try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Response in + let delegate = try FileDownloadDelegate(path: path) + + let response = try self.defaultClient.execute( + request: request, + delegate: delegate + ) + .wait() + + try XCTAssertEqual(50, TemporaryFileHelpers.fileSize(path: path)) + + return response + } + + XCTAssertEqual(.ok, response.head.status) + XCTAssertEqual("chunked", response.head.headers.first(name: "transfer-encoding")) + XCTAssertFalse(response.head.headers.contains(name: "content-length")) + + XCTAssertEqual(nil, response.totalBytes) + XCTAssertEqual(50, response.receivedBytes) + } + + func testCloseWhileBackpressureIsExertedIsFine() throws { + let request = try Request(url: self.defaultHTTPBinURLPrefix + "close-on-response") + let delegate = DelayOnHeadDelegate(eventLoop: self.clientGroup.next()) { _, promise in + promise.futureResult.eventLoop.scheduleTask(in: .milliseconds(50)) { + promise.succeed(()) + } + } + + let resultFuture = self.defaultClient.execute(request: request, delegate: delegate) + + // The full response must be correctly delivered. + var data = try resultFuture.wait() + guard let info = try data.readJSONDecodable(RequestInfo.self, length: data.readableBytes) else { + XCTFail("Could not parse response") + return + } + XCTAssertEqual(info.data, "some body content") + } + + func testErrorAfterCloseWhileBackpressureExerted() throws { + enum ExpectedError: Error { + case expected + } + + let request = try Request(url: self.defaultHTTPBinURLPrefix + "close-on-response") + let delegate = DelayOnHeadDelegate(eventLoop: self.clientGroup.next()) { _, backpressurePromise in + backpressurePromise.fail(ExpectedError.expected) + } + + let resultFuture = self.defaultClient.execute(request: request, delegate: delegate) + + // The task must be failed. + XCTAssertThrowsError(try resultFuture.wait()) { error in + XCTAssertEqual(error as? ExpectedError, .expected) + } + } + + func testRequestSpecificTLS() throws { + let configuration = HTTPClient.Configuration( + tlsConfiguration: nil, + timeout: .init(), + decompression: .disabled + ) + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: configuration + ) + let decoder = JSONDecoder() + + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + + // First two requests use identical TLS configurations. + var tlsConfig = TLSConfiguration.makeClientConfiguration() + tlsConfig.certificateVerification = .none + let firstRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + tlsConfiguration: tlsConfig + ) + let firstResponse = try localClient.execute(request: firstRequest).wait() + guard let firstBody = firstResponse.body else { + XCTFail("No request body found") + return + } + let firstConnectionNumber = try decoder.decode(RequestInfo.self, from: firstBody).connectionNumber + + let secondRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + tlsConfiguration: tlsConfig + ) + let secondResponse = try localClient.execute(request: secondRequest).wait() + guard let secondBody = secondResponse.body else { + XCTFail("No request body found") + return + } + let secondConnectionNumber = try decoder.decode(RequestInfo.self, from: secondBody).connectionNumber + + // Uses a differrent TLS config. + var tlsConfig2 = TLSConfiguration.makeClientConfiguration() + tlsConfig2.certificateVerification = .none + tlsConfig2.maximumTLSVersion = .tlsv1 + let thirdRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + tlsConfiguration: tlsConfig2 + ) + let thirdResponse = try localClient.execute(request: thirdRequest).wait() + guard let thirdBody = thirdResponse.body else { + XCTFail("No request body found") + return + } + let thirdConnectionNumber = try decoder.decode(RequestInfo.self, from: thirdBody).connectionNumber + + XCTAssertEqual(firstResponse.status, .ok) + XCTAssertEqual(secondResponse.status, .ok) + XCTAssertEqual(thirdResponse.status, .ok) + XCTAssertEqual( + firstConnectionNumber, + secondConnectionNumber, + "Identical TLS configurations did not use the same connection" + ) + XCTAssertNotEqual( + thirdConnectionNumber, + firstConnectionNumber, + "Different TLS configurations did not use different connections." + ) } func testRequestWithHeaderTransferEncodingIdentityDoesNotFail() { @@ -3114,4 +4115,483 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try client.execute(request: request).wait()) } + + func testMassiveDownload() { + var response: HTTPClient.Response? + XCTAssertNoThrow( + response = try self.defaultClient.get(url: "\(self.defaultHTTPBinURLPrefix)mega-chunked").wait() + ) + + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http1_1) + XCTAssertEqual(response?.body?.readableBytes, 10_000) + } + + func testShutdownWithFutures() { + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) + XCTAssertNoThrow(try httpClient.shutdown().wait()) + } + + func testMassiveHeaderHTTP1() throws { + var request = try HTTPClient.Request(url: defaultHTTPBin.baseURL, method: .POST) + // add ~64 KB header + let headerValue = String(repeating: "0", count: 1024) + for headerID in 0..<64 { + request.headers.replaceOrAdd(name: "larg-header-\(headerID)", value: headerValue) + } + + // non empty body is important to trigger this bug as we otherwise finish the request in a single flush + request.body = .byteBuffer(ByteBuffer(bytes: [0])) + + XCTAssertNoThrow(try defaultClient.execute(request: request).wait()) + } + + func testMassiveHeaderHTTP2() throws { + let bin = HTTPBin( + .http2(settings: [ + .init(parameter: .maxConcurrentStreams, value: 100), + .init(parameter: .maxHeaderListSize, value: 1024 * 256), + .init(parameter: .maxFrameSize, value: 1024 * 256), + ]) + ) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let client = HTTPClient( + eventLoopGroupProvider: .shared(clientGroup), + configuration: .init(certificateVerification: .none) + ) + + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + var request = try HTTPClient.Request(url: bin.baseURL, method: .POST) + // add ~200 KB header + let headerValue = String(repeating: "0", count: 1024) + for headerID in 0..<200 { + request.headers.replaceOrAdd(name: "larg-header-\(headerID)", value: headerValue) + } + + // non empty body is important to trigger this bug as we otherwise finish the request in a single flush + request.body = .byteBuffer(ByteBuffer(bytes: [0])) + + XCTAssertNoThrow(try client.execute(request: request).wait()) + } + + func testCancelingRequestAfterRedirect() throws { + let request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": self.defaultHTTPBinURLPrefix + "wait"], + body: nil + ) + + final class CancelAfterRedirect: HTTPClientResponseDelegate, Sendable { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + } + + let task = defaultClient.execute( + request: request, + delegate: CancelAfterRedirect(), + deadline: .now() + .seconds(1) + ) + + // there is currently no HTTPClientResponseDelegate method to ensure the redirect occurs before we cancel, so we just sleep for 500ms + Thread.sleep(forTimeInterval: 0.5) + + task.cancel() + + XCTAssertThrowsError(try task.wait()) { error in + guard case let error = error as? HTTPClientError, error == .cancelled else { + return XCTFail("Should fail with cancelled") + } + } + } + + func testFailingRequestAfterRedirect() throws { + let request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": self.defaultHTTPBinURLPrefix + "wait"], + body: nil + ) + + final class FailAfterRedirect: HTTPClientResponseDelegate, Sendable { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + } + + let task = defaultClient.execute( + request: request, + delegate: FailAfterRedirect(), + deadline: .now() + .seconds(1) + ) + + // there is currently no HTTPClientResponseDelegate method to ensure the redirect occurs before we fail, so we just sleep for 500ms + Thread.sleep(forTimeInterval: 0.5) + + struct TestError: Error {} + + task.fail(reason: TestError()) + + XCTAssertThrowsError(try task.wait()) { error in + guard error is TestError else { + return XCTFail("Should fail with TestError") + } + } + } + + func testCancelingHTTP1RequestAfterHeaderSend() throws { + var request = try HTTPClient.Request(url: self.defaultHTTPBin.baseURL + "/wait", method: .POST) + // non-empty body is important + request.body = .byteBuffer(ByteBuffer([1])) + + final class CancelAfterHeadSend: HTTPClientResponseDelegate, Sendable { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { + task.cancel() + } + } + XCTAssertThrowsError(try defaultClient.execute(request: request, delegate: CancelAfterHeadSend()).wait()) + } + + func testCancelingHTTP2RequestAfterHeaderSend() throws { + let bin = HTTPBin(.http2()) + defer { XCTAssertNoThrow(try bin.shutdown()) } + var request = try HTTPClient.Request(url: bin.baseURL + "/wait", method: .POST) + // non-empty body is important + request.body = .byteBuffer(ByteBuffer([1])) + + final class CancelAfterHeadSend: HTTPClientResponseDelegate, Sendable { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { + task.cancel() + } + } + XCTAssertThrowsError(try defaultClient.execute(request: request, delegate: CancelAfterHeadSend()).wait()) + } + + private func testMaxConnectionReuses(mode: HTTPBin.Mode, maximumUses: Int, requests: Int) throws { + let bin = HTTPBin(mode) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + var configuration = HTTPClient.Configuration(certificateVerification: .none) + // Limit each connection to two uses before discarding them. The test will verify that the + // connection number indicated by the server increments every two requests. + configuration.maximumUsesPerConnection = maximumUses + + let client = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: configuration) + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + let request = try HTTPClient.Request(url: bin.baseURL + "stats") + let decoder = JSONDecoder() + + // Do two requests per batch. Both should report the same connection number. + for requestNumber in stride(from: 0, to: requests, by: maximumUses) { + var responses = [RequestInfo]() + + for _ in 0..(0) + var executionCount: Int { self._executionCount.withLockedValue { $0 } } + + /// The minimum time to spend running the debug initializer. + static let duration: TimeAmount = .milliseconds(300) + + /// The actual debug initializer. + func initialize(channel: Channel) -> EventLoopFuture { + self._executionCount.withLockedValue { $0 += 1 } + + let someScheduledTask = channel.eventLoop.scheduleTask(in: Self.duration) { + channel.eventLoop.makeSucceededVoidFuture() + } + + return someScheduledTask.futureResult.flatMap { $0 } + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTracingInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTracingInternalTests.swift new file mode 100644 index 000000000..53f1138ba --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClientTracingInternalTests.swift @@ -0,0 +1,77 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Atomics +import InMemoryTracing +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOEmbedded +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import Tracing +import XCTest + +@testable @_spi(Tracing) import AsyncHTTPClient + +#if canImport(Network) +import Network +#endif + +private func makeTracedHTTPClient(tracer: InMemoryTracer) -> HTTPClient { + var config = HTTPClient.Configuration() + config.httpVersion = .automatic + config.tracing.tracer = tracer + return HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: config + ) +} + +final class HTTPClientTracingInternalTests: XCTestCaseHTTPClientTestsBaseClass { + + var tracer: InMemoryTracer! + var client: HTTPClient! + + override func setUp() { + super.setUp() + self.tracer = InMemoryTracer() + self.client = makeTracedHTTPClient(tracer: tracer) + } + + override func tearDown() { + if let client = self.client { + XCTAssertNoThrow(try client.syncShutdown()) + self.client = nil + } + tracer = nil + } + + func testTrace_preparedHeaders_include_fromSpan() async throws { + let url = self.defaultHTTPBinURLPrefix + "404-does-not-exist" + let request = HTTPClientRequest(url: url) + + try tracer.withSpan("operation") { span in + let prepared = try HTTPClientRequest.Prepared(request, tracing: self.client.tracing) + XCTAssertTrue(prepared.head.headers.count > 2) + XCTAssertTrue(prepared.head.headers.contains(name: "in-memory-trace-id")) + XCTAssertTrue(prepared.head.headers.contains(name: "in-memory-span-id")) + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTracingTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTracingTests.swift new file mode 100644 index 000000000..047c66e6d --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClientTracingTests.swift @@ -0,0 +1,150 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@_spi(Tracing) import AsyncHTTPClient // NOT @testable - tests that need @testable go into HTTPClientTracingInternalTests.swift +import Atomics +import InMemoryTracing +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOEmbedded +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import Tracing +import XCTest + +#if canImport(Network) +import Network +#endif + +private func makeTracedHTTPClient(tracer: InMemoryTracer) -> HTTPClient { + var config = HTTPClient.Configuration() + config.httpVersion = .automatic + config.tracing.tracer = tracer + return HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: config + ) +} + +final class HTTPClientTracingTests: XCTestCaseHTTPClientTestsBaseClass { + + var tracer: InMemoryTracer! + var client: HTTPClient! + + override func setUp() { + super.setUp() + self.tracer = InMemoryTracer() + self.client = makeTracedHTTPClient(tracer: tracer) + } + + override func tearDown() { + if let client = self.client { + XCTAssertNoThrow(try client.syncShutdown()) + self.client = nil + } + tracer = nil + } + + func testTrace_get_sync() throws { + let url = self.defaultHTTPBinURLPrefix + "echo-method" + let _ = try client.get(url: url).wait() + + guard tracer.activeSpans.isEmpty else { + XCTFail("Still active spans which were not finished (\(tracer.activeSpans.count))! \(tracer.activeSpans)") + return + } + guard let span = tracer.finishedSpans.first else { + XCTFail("No span was recorded!") + return + } + + XCTAssertEqual(span.operationName, "GET") + } + + func testTrace_post_sync() throws { + let url = self.defaultHTTPBinURLPrefix + "echo-method" + let _ = try client.post(url: url).wait() + + guard tracer.activeSpans.isEmpty else { + XCTFail("Still active spans which were not finished (\(tracer.activeSpans.count))! \(tracer.activeSpans)") + return + } + guard let span = tracer.finishedSpans.first else { + XCTFail("No span was recorded!") + return + } + + XCTAssertEqual(span.operationName, "POST") + } + + func testTrace_post_sync_404_error() throws { + let url = self.defaultHTTPBinURLPrefix + "404-not-existent" + let _ = try client.post(url: url).wait() + + guard tracer.activeSpans.isEmpty else { + XCTFail("Still active spans which were not finished (\(tracer.activeSpans.count))! \(tracer.activeSpans)") + return + } + guard let span = tracer.finishedSpans.first else { + XCTFail("No span was recorded!") + return + } + + XCTAssertEqual(span.operationName, "POST") + XCTAssertTrue(span.errors.isEmpty, "Should have recorded error") + XCTAssertEqual(span.attributes.get(client.tracing.attributeKeys.responseStatusCode), 404) + } + + func testTrace_execute_async() async throws { + let url = self.defaultHTTPBinURLPrefix + "echo-method" + let request = HTTPClientRequest(url: url) + let _ = try await client.execute(request, deadline: .distantFuture) + + guard tracer.activeSpans.isEmpty else { + XCTFail("Still active spans which were not finished (\(tracer.activeSpans.count))! \(tracer.activeSpans)") + return + } + guard let span = tracer.finishedSpans.first else { + XCTFail("No span was recorded!") + return + } + + XCTAssertEqual(span.operationName, "GET") + } + + func testTrace_execute_async_404_error() async throws { + let url = self.defaultHTTPBinURLPrefix + "404-does-not-exist" + let request = HTTPClientRequest(url: url) + let _ = try await client.execute(request, deadline: .distantFuture) + + guard tracer.activeSpans.isEmpty else { + XCTFail("Still active spans which were not finished (\(tracer.activeSpans.count))! \(tracer.activeSpans)") + return + } + guard let span = tracer.finishedSpans.first else { + XCTFail("No span was recorded!") + return + } + + XCTAssertEqual(span.operationName, "GET") + XCTAssertTrue(span.errors.isEmpty, "Should have recorded error") + XCTAssertEqual(span.attributes.get(client.tracing.attributeKeys.responseStatusCode), 404) + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests+XCTest.swift deleted file mode 100644 index d95346673..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests+XCTest.swift +++ /dev/null @@ -1,36 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPClientUncleanSSLConnectionShutdownTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPClientUncleanSSLConnectionShutdownTests { - static var allTests: [(String, (HTTPClientUncleanSSLConnectionShutdownTests) -> () throws -> Void)] { - return [ - ("testEOFFramedSuccess", testEOFFramedSuccess), - ("testContentLength", testContentLength), - ("testContentLengthButTruncated", testContentLengthButTruncated), - ("testTransferEncoding", testTransferEncoding), - ("testTransferEncodingButTruncated", testTransferEncodingButTruncated), - ("testConnectionDrop", testConnectionDrop), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift index 854d9092c..b63eb7cba 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift @@ -155,7 +155,8 @@ final class HTTPClientUncleanSSLConnectionShutdownTests: XCTestCase { ) defer { XCTAssertNoThrow(try client.syncShutdown()) } - XCTAssertThrowsError(try client.get(url: "https://localhost:\(httpBin.port)/transferencodingtruncated").wait()) { + XCTAssertThrowsError(try client.get(url: "https://localhost:\(httpBin.port)/transferencodingtruncated").wait()) + { XCTAssertEqual($0 as? HTTPParserError, .invalidEOFState) } } @@ -184,7 +185,7 @@ final class HTTPBinForSSLUncleanShutdown { let serverChannel: Channel var port: Int { - return Int(self.serverChannel.localAddress!.port!) + Int(self.serverChannel.localAddress!.port!) } init() { @@ -231,61 +232,61 @@ private final class HTTPBinForSSLUncleanShutdownHandler: ChannelInboundHandler { switch req.uri { case "/nocontentlength": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - \r\n\ - foo - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + \r\n\ + foo + """ case "/nocontent": response = """ - HTTP/1.1 204 OK\r\n\ - Connection: close\r\n\ - \r\n - """ + HTTP/1.1 204 OK\r\n\ + Connection: close\r\n\ + \r\n + """ case "/noresponse": response = nil case "/wrongcontentlength": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - Content-Length: 6\r\n\ - \r\n\ - foo - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + Content-Length: 6\r\n\ + \r\n\ + foo + """ case "/transferencoding": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - 3\r\n\ - foo\r\n\ - 0\r\n\ - \r\n - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 3\r\n\ + foo\r\n\ + 0\r\n\ + \r\n + """ case "/transferencodingtruncated": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - 12\r\n\ - foo - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 12\r\n\ + foo + """ default: response = """ - HTTP/1.1 404 OK\r\n\ - Connection: close\r\n\ - Content-Length: 9\r\n\ - \r\n\ - Not Found - """ + HTTP/1.1 404 OK\r\n\ + Connection: close\r\n\ + Content-Length: 9\r\n\ + \r\n\ + Not Found + """ } if let response = response { diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests+XCTest.swift deleted file mode 100644 index 898b2b867..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests+XCTest.swift +++ /dev/null @@ -1,34 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+FactoryTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_FactoryTests { - static var allTests: [(String, (HTTPConnectionPool_FactoryTests) -> () throws -> Void)] { - return [ - ("testConnectionCreationTimesoutIfDeadlineIsInThePast", testConnectionCreationTimesoutIfDeadlineIsInThePast), - ("testSOCKSConnectionCreationTimesoutIfRemoteIsUnresponsive", testSOCKSConnectionCreationTimesoutIfRemoteIsUnresponsive), - ("testHTTPProxyConnectionCreationTimesoutIfRemoteIsUnresponsive", testHTTPProxyConnectionCreationTimesoutIfRemoteIsUnresponsive), - ("testTLSConnectionCreationTimesoutIfRemoteIsUnresponsive", testTLSConnectionCreationTimesoutIfRemoteIsUnresponsive), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift index b13ff3d18..37ff3a1ef 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOPosix @@ -20,18 +19,22 @@ import NIOSOCKS import NIOSSL import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_FactoryTests: XCTestCase { func testConnectionCreationTimesoutIfDeadlineIsInThePast() { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -45,14 +48,19 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - connectionID: 1, - deadline: .now() - .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() - .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { - XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) + guard let error = $0 as? ChannelError, case .connectTimeout = error else { + XCTFail("Unexpected error: \($0)") + return + } } } @@ -61,12 +69,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -76,16 +86,19 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: nil, - clientConfiguration: .init(proxy: .socksServer(host: "127.0.0.1", port: server!.localAddress!.port!)), + clientConfiguration: .init(proxy: .socksServer(host: "127.0.0.1", port: server!.localAddress!.port!)) + .enableFastFailureModeForTesting(), sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - connectionID: 1, - deadline: .now() + .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() + .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .socksHandshakeTimeout) } @@ -96,12 +109,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -111,16 +126,19 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: nil, - clientConfiguration: .init(proxy: .server(host: "127.0.0.1", port: server!.localAddress!.port!)), + clientConfiguration: .init(proxy: .server(host: "127.0.0.1", port: server!.localAddress!.port!)) + .enableFastFailureModeForTesting(), sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - connectionID: 1, - deadline: .now() + .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() + .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .httpProxyHandshakeTimeout) } @@ -131,12 +149,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -148,26 +168,69 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { let factory = HTTPConnectionPool.ConnectionFactory( key: .init(request), tlsConfiguration: nil, - clientConfiguration: .init(tlsConfiguration: tlsConfig), + clientConfiguration: .init(tlsConfiguration: tlsConfig) + .enableFastFailureModeForTesting(), sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - connectionID: 1, - deadline: .now() + .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() + .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .tlsHandshakeTimeout) } } } -class NeverrespondServerHandler: ChannelInboundHandler { +final class NeverrespondServerHandler: ChannelInboundHandler, Sendable { typealias InboundIn = NIOAny func channelRead(context: ChannelHandlerContext, data: NIOAny) { // do nothing } } + +/// A `HTTPConnectionRequester` that will fail a test if any of its methods are ever called. +final class ExplodingRequester: HTTPConnectionRequester { + func http1ConnectionCreated(_: HTTP1Connection.SendableView) { + XCTFail("http1ConnectionCreated called unexpectedly") + } + + func http2ConnectionCreated(_: HTTP2Connection.SendableView, maximumStreams: Int) { + XCTFail("http2ConnectionCreated called unexpectedly") + } + + func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Error) { + XCTFail("failedToCreateHTTPConnection called unexpectedly") + } + + func waitingForConnectivity(_: HTTPConnectionPool.Connection.ID, error: Error) { + XCTFail("waitingForConnectivity called unexpectedly") + } +} + +extension HTTPConnectionPool.ConnectionFactory { + fileprivate func makeChannel( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + deadline: NIODeadline, + eventLoop: EventLoop, + logger: Logger + ) -> EventLoopFuture { + let promise = eventLoop.makePromise(of: NegotiatedProtocol.self) + self.makeChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger, + promise: promise + ) + return promise.futureResult + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest+XCTest.swift deleted file mode 100644 index 21eb3029e..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest+XCTest.swift +++ /dev/null @@ -1,47 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+HTTP1ConnectionsTest+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_HTTP1ConnectionsTests { - static var allTests: [(String, (HTTPConnectionPool_HTTP1ConnectionsTests) -> () throws -> Void)] { - return [ - ("testCreatingConnections", testCreatingConnections), - ("testCreatingConnectionAndFailing", testCreatingConnectionAndFailing), - ("testLeaseConnectionOnPreferredAndAvailableEL", testLeaseConnectionOnPreferredAndAvailableEL), - ("testLeaseConnectionOnPreferredButUnavailableEL", testLeaseConnectionOnPreferredButUnavailableEL), - ("testLeaseConnectionOnRequiredButUnavailableEL", testLeaseConnectionOnRequiredButUnavailableEL), - ("testLeaseConnectionOnRequiredAndAvailableEL", testLeaseConnectionOnRequiredAndAvailableEL), - ("testCloseConnectionIfIdle", testCloseConnectionIfIdle), - ("testCloseConnectionIfIdleButLeasedRaceCondition", testCloseConnectionIfIdleButLeasedRaceCondition), - ("testCloseConnectionIfIdleButClosedRaceCondition", testCloseConnectionIfIdleButClosedRaceCondition), - ("testShutdown", testShutdown), - ("testMigrationFromHTTP2", testMigrationFromHTTP2), - ("testMigrationFromHTTP2WithPendingRequestsWithRequiredEventLoop", testMigrationFromHTTP2WithPendingRequestsWithRequiredEventLoop), - ("testMigrationFromHTTP2WithPendingRequestsWithPreferredEventLoop", testMigrationFromHTTP2WithPendingRequestsWithPreferredEventLoop), - ("testMigrationFromHTTP2WithAlreadyLeasedHTTP1Connection", testMigrationFromHTTP2WithAlreadyLeasedHTTP1Connection), - ("testMigrationFromHTTP2WithMoreStartingConnectionsThanMaximumAllowedConccurentConnections", testMigrationFromHTTP2WithMoreStartingConnectionsThanMaximumAllowedConccurentConnections), - ("testMigrationFromHTTP2StartsEnoghOverflowConnectionsForRequiredEventLoopRequests", testMigrationFromHTTP2StartsEnoghOverflowConnectionsForRequiredEventLoopRequests), - ("testMigrationFromHTTP1ToHTTP2AndBackToHTTP1", testMigrationFromHTTP1ToHTTP2AndBackToHTTP1), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift index 5afe755a1..89f3bf7b5 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift @@ -12,15 +12,21 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCreatingConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() let el2 = elg.next() @@ -52,7 +58,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCreatingConnectionAndFailing() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() let el2 = elg.next() @@ -103,7 +114,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el3 = elg.next() let el4 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) for el in [el1, el2, el3, el4] { XCTAssertEqual(connections.startingGeneralPurposeConnections, 0) @@ -130,7 +146,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el4 = elg.next() let el5 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) for el in [el1, el2, el3, el4] { XCTAssertEqual(connections.startingGeneralPurposeConnections, 0) @@ -157,7 +178,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el4 = elg.next() let el5 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) for el in [el1, el2, el3, el4] { XCTAssertEqual(connections.startingGeneralPurposeConnections, 0) @@ -181,7 +207,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el1 = elg.next() let el2 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) for el in [el1, el1, el1, el1, el2] { let connID = connections.createNewConnection(on: el) @@ -228,7 +259,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCloseConnectionIfIdle() { let elg = EmbeddedEventLoopGroup(loops: 1) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() @@ -248,7 +284,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCloseConnectionIfIdleButLeasedRaceCondition() { let elg = EmbeddedEventLoopGroup(loops: 1) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() @@ -267,7 +308,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCloseConnectionIfIdleButClosedRaceCondition() { let elg = EmbeddedEventLoopGroup(loops: 1) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() @@ -288,7 +334,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el3 = elg.next() let el4 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init()) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) for el in [el1, el2, el3, el4] { let connID = connections.createNewConnection(on: el) @@ -333,6 +384,8 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { XCTAssertEqual(connections.closeConnection(at: releaseIndex), lease) XCTAssertFalse(connections.isEmpty) + let backoffEL = connections.backoffNextConnectionAttempt(startingID) + XCTAssertIdentical(backoffEL, el2) guard let (failIndex, _) = connections.failConnection(startingID) else { return XCTFail("Expected that the connection is remembered") } @@ -343,7 +396,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() let el2 = elg.next() @@ -372,7 +430,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2WithPendingRequestsWithRequiredEventLoop() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() let el2 = elg.next() @@ -408,10 +471,48 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { XCTAssertTrue(context.eventLoop === el3) } + func testMigrationFromHTTP2WithPendingRequestsWithRequiredEventLoopSameAsStartingConnections() { + let elg = EmbeddedEventLoopGroup(loops: 4) + let generator = HTTPConnectionPool.Connection.ID.Generator() + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) + + let el1 = elg.next() + let el2 = elg.next() + + let conn1ID = generator.next() + let conn2ID = generator.next() + + connections.migrateFromHTTP2( + starting: [(conn1ID, el1)], + backingOff: [(conn2ID, el2)] + ) + + let stats = connections.stats + XCTAssertEqual(stats.idle, 0) + XCTAssertEqual(stats.leased, 0) + XCTAssertEqual(stats.connecting, 1) + XCTAssertEqual(stats.backingOff, 1) + + let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) + let (_, context) = connections.newHTTP1ConnectionEstablished(conn1) + XCTAssertEqual(context.use, .generalPurpose) + XCTAssertTrue(context.eventLoop === el1) + } + func testMigrationFromHTTP2WithPendingRequestsWithPreferredEventLoop() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() let el2 = elg.next() @@ -450,7 +551,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2WithAlreadyLeasedHTTP1Connection() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() let el2 = elg.next() let el3 = elg.next() @@ -494,7 +600,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2WithMoreStartingConnectionsThanMaximumAllowedConccurentConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 2, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 2, + generator: generator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() let el2 = elg.next() @@ -529,7 +640,12 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2StartsEnoghOverflowConnectionsForRequiredEventLoopRequests() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 1, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 1, + generator: generator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let el1 = elg.next() let el2 = elg.next() @@ -571,16 +687,24 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el2 = elg.next() let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let connID1 = connections.createNewConnection(on: el1) let context = connections.migrateToHTTP2() - XCTAssertEqual(context, .init( - backingOff: [], - starting: [(connID1, el1)], - close: [] - )) + XCTAssertEqual( + context, + .init( + backingOff: [], + starting: [(connID1, el1)], + close: [] + ) + ) let connID2 = generator.next() @@ -598,8 +722,7 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { extension HTTPConnectionPool.HTTP1Connections.HTTP1ToHTTP2MigrationContext: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { - return lhs.close == rhs.close && - lhs.starting.elementsEqual(rhs.starting, by: { $0.0 == $1.0 && $0.1 === $1.1 }) && - lhs.backingOff.elementsEqual(rhs.backingOff, by: { $0.0 == $1.0 && $0.1 === $1.1 }) + lhs.close == rhs.close && lhs.starting.elementsEqual(rhs.starting, by: { $0.0 == $1.0 && $0.1 === $1.1 }) + && lhs.backingOff.elementsEqual(rhs.backingOff, by: { $0.0 == $1.0 && $0.1 === $1.1 }) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests+XCTest.swift deleted file mode 100644 index 16377d07f..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests+XCTest.swift +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+HTTP1StateTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_HTTP1StateMachineTests { - static var allTests: [(String, (HTTPConnectionPool_HTTP1StateMachineTests) -> () throws -> Void)] { - return [ - ("testCreatingAndFailingConnections", testCreatingAndFailingConnections), - ("testConnectionFailureBackoff", testConnectionFailureBackoff), - ("testCancelRequestWorks", testCancelRequestWorks), - ("testExecuteOnShuttingDownPool", testExecuteOnShuttingDownPool), - ("testRequestsAreQueuedIfAllConnectionsAreInUseAndRequestsAreDequeuedInOrder", testRequestsAreQueuedIfAllConnectionsAreInUseAndRequestsAreDequeuedInOrder), - ("testBestConnectionIsPicked", testBestConnectionIsPicked), - ("testConnectionAbortIsIgnoredIfThereAreNoQueuedRequests", testConnectionAbortIsIgnoredIfThereAreNoQueuedRequests), - ("testConnectionCloseLeadsToTumbleWeedIfThereNoQueuedRequests", testConnectionCloseLeadsToTumbleWeedIfThereNoQueuedRequests), - ("testConnectionAbortLeadsToNewConnectionsIfThereAreQueuedRequests", testConnectionAbortLeadsToNewConnectionsIfThereAreQueuedRequests), - ("testParkedConnectionTimesOut", testParkedConnectionTimesOut), - ("testConnectionPoolFullOfParkedConnectionsIsShutdownImmediately", testConnectionPoolFullOfParkedConnectionsIsShutdownImmediately), - ("testParkedConnectionTimesOutButIsAlsoClosedByRemote", testParkedConnectionTimesOutButIsAlsoClosedByRemote), - ("testConnectionBackoffVsShutdownRace", testConnectionBackoffVsShutdownRace), - ("testRequestThatTimesOutIsFailedWithLastConnectionCreationError", testRequestThatTimesOutIsFailedWithLastConnectionCreationError), - ("testRequestThatTimesOutBeforeAConnectionIsEstablishedIsFailedWithConnectTimeoutError", testRequestThatTimesOutBeforeAConnectionIsEstablishedIsFailedWithConnectTimeoutError), - ("testRequestThatTimesOutAfterAConnectionWasEstablishedSuccessfullyTimesOutWithGenericError", testRequestThatTimesOutAfterAConnectionWasEstablishedSuccessfullyTimesOutWithGenericError), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift index 49a6fb574..9146f0593 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift @@ -12,21 +12,27 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { func testCreatingAndFailingConnections() { + struct SomeError: Error, Equatable {} let elg = EmbeddedEventLoopGroup(loops: 4) defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 8 + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 ) var connections = MockConnectionPool() @@ -35,7 +41,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // for the first eight requests, the pool should try to create new connections. for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .createConnection(let connectionID, let connectionEL) = action.connection else { @@ -51,7 +57,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // the next eight requests should only be queued. for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .none = action.connection else { @@ -65,8 +71,6 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // fail all connection attempts while let randomConnectionID = connections.randomStartingConnection() { - struct SomeError: Error, Equatable {} - XCTAssertNoThrow(try connections.failConnectionCreation(randomConnectionID)) let action = state.failedToCreateNewConnection(SomeError(), connectionID: randomConnectionID) @@ -86,9 +90,9 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // cancel all queued requests while let request = queuer.timeoutRandomRequest() { - let cancelAction = state.cancelRequest(request) + let cancelAction = state.cancelRequest(request.0) XCTAssertEqual(cancelAction.connection, .none) - XCTAssertEqual(cancelAction.request, .cancelRequestTimeout(request)) + XCTAssertEqual(cancelAction.request, .failRequest(.init(request.1), SomeError(), cancelTimeout: true)) } // connection backoff done @@ -103,16 +107,91 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssert(connections.isEmpty) } + func testCreatingAndFailingConnectionsWithoutRetry() { + struct SomeError: Error, Equatable {} + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: false, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) + + var connections = MockConnectionPool() + var queuer = MockRequestQueuer() + + // for the first eight requests, the pool should try to create new connections. + + for _ in 0..<8 { + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + let action = state.executeRequest(request) + guard case .createConnection(let connectionID, let connectionEL) = action.connection else { + return XCTFail("Unexpected connection action") + } + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + XCTAssert(connectionEL === mockRequest.eventLoop) + + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + XCTAssertNoThrow(try queuer.queue(mockRequest, id: request.id)) + } + + // the next eight requests should only be queued. + + for _ in 0..<8 { + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + let action = state.executeRequest(request) + guard case .none = action.connection else { + return XCTFail("Unexpected connection action") + } + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + XCTAssertNoThrow(try queuer.queue(mockRequest, id: request.id)) + } + + // the first failure should cancel all requests because we have disabled connection establishtment retry + let randomConnectionID = connections.randomStartingConnection()! + XCTAssertNoThrow(try connections.failConnectionCreation(randomConnectionID)) + let action = state.failedToCreateNewConnection(SomeError(), connectionID: randomConnectionID) + XCTAssertEqual(action.connection, .none) + guard case .failRequestsAndCancelTimeouts(let requestsToFail, let requestError) = action.request else { + return XCTFail("Unexpected request action: \(action.request)") + } + XCTAssertEqualTypeAndValue(requestError, SomeError()) + for requestToFail in requestsToFail { + XCTAssertNoThrow(try queuer.fail(requestToFail.id, request: requestToFail.__testOnly_wrapped_request())) + } + + // all requests have been canceled and therefore nothing should happen if a connection fails + while let randomConnectionID = connections.randomStartingConnection() { + XCTAssertNoThrow(try connections.failConnectionCreation(randomConnectionID)) + let action = state.failedToCreateNewConnection(SomeError(), connectionID: randomConnectionID) + + XCTAssertEqual(action, .none) + } + + XCTAssert(queuer.isEmpty) + XCTAssert(connections.isEmpty) + } + func testConnectionFailureBackoff() { let elg = EmbeddedEventLoopGroup(loops: 4) defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 2 + maximumConcurrentHTTP1Connections: 2, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -122,9 +201,12 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { return XCTFail("Unexpected connection action: \(action.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux - let failedConnect1 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: connectionID) + let failedConnect1 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: connectionID + ) XCTAssertEqual(failedConnect1.request, .none) guard case .scheduleBackoffTimer(connectionID, let backoffTimeAmount1, _) = failedConnect1.connection else { return XCTFail("Unexpected connection action: \(failedConnect1.connection)") @@ -137,9 +219,12 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { return XCTFail("Unexpected connection action: \(backoffDoneAction.connection)") } XCTAssertGreaterThan(newConnectionID, connectionID) - XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux - let failedConnect2 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: newConnectionID) + let failedConnect2 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: newConnectionID + ) XCTAssertEqual(failedConnect2.request, .none) guard case .scheduleBackoffTimer(newConnectionID, let backoffTimeAmount2, _) = failedConnect2.connection else { return XCTFail("Unexpected connection action: \(failedConnect2.connection)") @@ -152,7 +237,9 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .failRequest(let requestToFail, let requestError, cancelTimeout: false) = failRequest.request else { return XCTFail("Unexpected request action: \(action.request)") } - XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + + // XCTAssertIdentical not available on Linux + XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) XCTAssertEqual(requestError as? HTTPClientError, .connectTimeout) XCTAssertEqual(failRequest.connection, .none) @@ -166,10 +253,14 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 2 + maximumConcurrentHTTP1Connections: 2, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -179,19 +270,21 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. cancel request let cancelAction = state.cancelRequest(request.id) - XCTAssertEqual(cancelAction.request, .cancelRequestTimeout(request.id)) + XCTAssertEqual(cancelAction.request, .failRequest(request, HTTPClientError.cancelled, cancelTimeout: true)) XCTAssertEqual(cancelAction.connection, .none) // 3. request timeout triggers to late XCTAssertEqual(state.timeoutRequest(request.id), .none, "To late timeout is ignored") // 4. succeed connection attempt - let connectedAction = state.newHTTP1ConnectionCreated(.__testOnly_connection(id: connectionID, eventLoop: connectionEL)) + let connectedAction = state.newHTTP1ConnectionCreated( + .__testOnly_connection(id: connectionID, eventLoop: connectionEL) + ) XCTAssertEqual(connectedAction.request, .none, "Request must not be executed") XCTAssertEqual(connectedAction.connection, .scheduleTimeoutTimer(connectionID, on: connectionEL)) } @@ -202,10 +295,14 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 2 + maximumConcurrentHTTP1Connections: 2, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -215,15 +312,18 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. connection succeeds - let connection: HTTPConnectionPool.Connection = .__testOnly_connection(id: connectionID, eventLoop: connectionEL) + let connection: HTTPConnectionPool.Connection = .__testOnly_connection( + id: connectionID, + eventLoop: connectionEL + ) let connectedAction = state.newHTTP1ConnectionCreated(connection) guard case .executeRequest(request, connection, cancelTimeout: true) = connectedAction.request else { return XCTFail("Unexpected request action: \(connectedAction.request)") } - XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux XCTAssertEqual(connectedAction.connection, .none) // 3. shutdown @@ -239,11 +339,14 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(cleanupContext.connectBackoff, []) // 4. execute another request - let finalMockRequest = MockHTTPRequest(eventLoop: elg.next()) + let finalMockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let finalRequest = HTTPConnectionPool.Request(finalMockRequest) let failAction = state.executeRequest(finalRequest) XCTAssertEqual(failAction.connection, .none) - XCTAssertEqual(failAction.request, .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false)) + XCTAssertEqual( + failAction.request, + .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false) + ) // 5. close open connection let closeAction = state.http1ConnectionClosed(connectionID) @@ -264,16 +367,20 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // Add eight requests to fill all connections for _ in 0..<8 { let eventLoop = elg.next() - guard let expectedConnection = connections.newestParkedConnection(for: eventLoop) ?? connections.newestParkedConnection else { + guard + let expectedConnection = connections.newestParkedConnection(for: eventLoop) + ?? connections.newestParkedConnection + else { return XCTFail("Expected to still have connections available") } - let mockRequest = MockHTTPRequest(eventLoop: eventLoop) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) XCTAssertEqual(action.connection, .cancelTimeoutTimer(expectedConnection.id)) - guard case .executeRequest(let returnedRequest, expectedConnection, cancelTimeout: false) = action.request else { + guard case .executeRequest(let returnedRequest, expectedConnection, cancelTimeout: false) = action.request + else { return XCTFail("Expected to execute a request next, but got: \(action.request)") } @@ -288,7 +395,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var queuer = MockRequestQueuer() for _ in 0..<100 { let eventLoop = elg.next() - let mockRequest = MockHTTPRequest(eventLoop: eventLoop, requiresEventLoopForChannel: false) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop, requiresEventLoopForChannel: false) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -347,7 +454,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // 10% of the cases enforce the eventLoop let elRequired = (0..<10).randomElement().flatMap { $0 == 0 ? true : false }! - let mockRequest = MockHTTPRequest(eventLoop: reqEventLoop, requiresEventLoopForChannel: elRequired) + let mockRequest = MockHTTPScheduableRequest( + eventLoop: reqEventLoop, + requiresEventLoopForChannel: elRequired + ) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -359,7 +469,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssert(connEventLoop === reqEventLoop) XCTAssertEqual(action.request, .scheduleRequestTimeout(for: request, on: reqEventLoop)) - let connection: HTTPConnectionPool.Connection = .__testOnly_connection(id: connectionID, eventLoop: connEventLoop) + let connection: HTTPConnectionPool.Connection = .__testOnly_connection( + id: connectionID, + eventLoop: connEventLoop + ) let createdAction = state.newHTTP1ConnectionCreated(connection) XCTAssertEqual(createdAction.request, .executeRequest(request, connection, cancelTimeout: true)) XCTAssertEqual(createdAction.connection, .none) @@ -370,7 +483,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(state.http1ConnectionClosed(connectionID), .none) case .cancelTimeoutTimer(let connectionID): - guard let expectedConnection = connections.newestParkedConnection(for: reqEventLoop) ?? connections.newestParkedConnection else { + guard + let expectedConnection = connections.newestParkedConnection(for: reqEventLoop) + ?? connections.newestParkedConnection + else { return XCTFail("Expected to have connections available") } @@ -378,7 +494,11 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssert(expectedConnection.eventLoop === reqEventLoop) } - XCTAssertEqual(connectionID, expectedConnection.id, "Request is scheduled on the connection we expected") + XCTAssertEqual( + connectionID, + expectedConnection.id, + "Request is scheduled on the connection we expected" + ) XCTAssertNoThrow(try connections.activateConnection(connectionID)) guard case .executeRequest(let request, let connection, cancelTimeout: false) = action.request else { @@ -388,8 +508,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.execute(request.__testOnly_wrapped_request(), on: connection)) XCTAssertNoThrow(try connections.finishExecution(connection.id)) - XCTAssertEqual(state.http1ConnectionReleased(connection.id), - .init(request: .none, connection: .scheduleTimeoutTimer(connection.id, on: connection.eventLoop))) + XCTAssertEqual( + state.http1ConnectionReleased(connection.id), + .init(request: .none, connection: .scheduleTimeoutTimer(connection.id, on: connection.eventLoop)) + ) XCTAssertNoThrow(try connections.parkConnection(connectionID)) default: @@ -411,7 +533,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(connections.parked, 8) // close a leased connection == abort - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) guard let connectionToAbort = connections.newestParkedConnection else { return XCTFail("Expected to have a parked connection") @@ -461,11 +583,14 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // Add eight requests to fill all connections for _ in 0..<8 { let eventLoop = elg.next() - guard let expectedConnection = connections.newestParkedConnection(for: eventLoop) ?? connections.newestParkedConnection else { + guard + let expectedConnection = connections.newestParkedConnection(for: eventLoop) + ?? connections.newestParkedConnection + else { return XCTFail("Expected to still have connections available") } - let mockRequest = MockHTTPRequest(eventLoop: eventLoop) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -482,7 +607,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { for _ in 0..<100 { let eventLoop = elg.next() - let mockRequest = MockHTTPRequest(eventLoop: eventLoop, requiresEventLoopForChannel: false) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop, requiresEventLoopForChannel: false) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -508,12 +633,20 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard let newConnection = maybeNewConnection else { return XCTFail("Expected to get a new connection") } let afterRecreationAction = state.newHTTP1ConnectionCreated(newConnection) XCTAssertEqual(afterRecreationAction.connection, .none) - guard case .executeRequest(let request, newConnection, cancelTimeout: true) = afterRecreationAction.request else { + guard + case .executeRequest(let request, newConnection, cancelTimeout: true) = afterRecreationAction + .request + else { return XCTFail("Unexpected request action: \(action.request)") } XCTAssertEqual(request.id, queuedRequestsOrder.popFirst()) - XCTAssertNoThrow(try connections.execute(queuer.get(request.id, request: request.__testOnly_wrapped_request()), on: newConnection)) + XCTAssertNoThrow( + try connections.execute( + queuer.get(request.id, request: request.__testOnly_wrapped_request()), + on: newConnection + ) + ) case .none: XCTAssert(queuer.isEmpty) @@ -535,7 +668,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { return XCTFail("Expected to have one parked connection") } - let action = state.connectionIdleTimeout(connection.id) + let action = state.connectionIdleTimeout(connection.id, on: connection.eventLoop) XCTAssertEqual(action.connection, .closeConnection(connection, isShutdown: .no)) XCTAssertEqual(action.request, .none) XCTAssertNoThrow(try connections.closeConnection(connection)) @@ -583,7 +716,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(state.http1ConnectionClosed(connection.id), .none) // triggered by timer - XCTAssertEqual(state.connectionIdleTimeout(connection.id), .none) + XCTAssertEqual(state.connectionIdleTimeout(connection.id, on: connection.eventLoop), .none) } func testConnectionBackoffVsShutdownRace() { @@ -592,10 +725,14 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 6 + maximumConcurrentHTTP1Connections: 6, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -630,10 +767,14 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 6 + maximumConcurrentHTTP1Connections: 6, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -643,7 +784,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction.request, .scheduleRequestTimeout(for: request, on: mockRequest.eventLoop)) - let failAction = state.failedToCreateNewConnection(HTTPClientError.httpProxyHandshakeTimeout, connectionID: connectionID) + let failAction = state.failedToCreateNewConnection( + HTTPClientError.httpProxyHandshakeTimeout, + connectionID: connectionID + ) guard case .scheduleBackoffTimer(connectionID, backoff: _, on: let timerEL) = failAction.connection else { return XCTFail("Expected to create a backoff timer") } @@ -651,7 +795,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(failAction.request, .none) let timeoutAction = state.timeoutRequest(request.id) - XCTAssertEqual(timeoutAction.request, .failRequest(request, HTTPClientError.httpProxyHandshakeTimeout, cancelTimeout: false)) + XCTAssertEqual( + timeoutAction.request, + .failRequest(request, HTTPClientError.httpProxyHandshakeTimeout, cancelTimeout: false) + ) XCTAssertEqual(timeoutAction.connection, .none) } @@ -661,10 +808,14 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 6 + maximumConcurrentHTTP1Connections: 6, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 ) - let mockRequest = MockHTTPRequest(eventLoop: eventLoop.next(), requiresEventLoopForChannel: false) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop.next(), requiresEventLoopForChannel: false) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -674,7 +825,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction.request, .scheduleRequestTimeout(for: request, on: mockRequest.eventLoop)) let timeoutAction = state.timeoutRequest(request.id) - XCTAssertEqual(timeoutAction.request, .failRequest(request, HTTPClientError.connectTimeout, cancelTimeout: false)) + XCTAssertEqual( + timeoutAction.request, + .failRequest(request, HTTPClientError.connectTimeout, cancelTimeout: false) + ) XCTAssertEqual(timeoutAction.connection, .none) } @@ -684,10 +838,14 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 6 + maximumConcurrentHTTP1Connections: 6, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 ) - let mockRequest1 = MockHTTPRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) let request1 = HTTPConnectionPool.Request(mockRequest1) let executeAction1 = state.executeRequest(request1) @@ -698,7 +856,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction1.request, .scheduleRequestTimeout(for: request1, on: mockRequest1.eventLoop)) - let mockRequest2 = MockHTTPRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) + let mockRequest2 = MockHTTPScheduableRequest(eventLoop: elg.next(), requiresEventLoopForChannel: false) let request2 = HTTPConnectionPool.Request(mockRequest2) let executeAction2 = state.executeRequest(request2) @@ -709,7 +867,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction2.request, .scheduleRequestTimeout(for: request2, on: connEL1)) - let failAction = state.failedToCreateNewConnection(HTTPClientError.httpProxyHandshakeTimeout, connectionID: connectionID1) + let failAction = state.failedToCreateNewConnection( + HTTPClientError.httpProxyHandshakeTimeout, + connectionID: connectionID1 + ) guard case .scheduleBackoffTimer(connectionID1, backoff: _, on: let timerEL) = failAction.connection else { return XCTFail("Expected to create a backoff timer") } @@ -723,7 +884,657 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(createdAction.connection, .none) let timeoutAction = state.timeoutRequest(request2.id) - XCTAssertEqual(timeoutAction.request, .failRequest(request2, HTTPClientError.getConnectionFromPoolTimeout, cancelTimeout: false)) + XCTAssertEqual( + timeoutAction.request, + .failRequest(request2, HTTPClientError.getConnectionFromPoolTimeout, cancelTimeout: false) + ) XCTAssertEqual(timeoutAction.connection, .none) } + + func testPrewarmingSimpleFlow() throws { + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 4 + ) + + var connectionIDs = [HTTPConnectionPool.Connection.ID]() + var connections = MockConnectionPool() + + // attempt to send one request. + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + var action = state.executeRequest(request) + guard case .createConnection(var connectionID, var connectionEL) = action.connection else { + return XCTFail("Unexpected connection action") + } + connectionIDs.append(connectionID) + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + + // We're going to end up creating 5 connections immediately, even though only one is leased: the other 4 are pre-warmed. + for connectionIndex in 0..<5 { + let conn = try connections.succeedConnectionCreationHTTP1(connectionID) + let createdAction = state.newHTTP1ConnectionCreated(conn) + + switch createdAction.request { + case .executeRequest(_, let connection, _): + try connections.execute(mockRequest, on: connection) + case .none: + try connections.parkConnection(connectionID) + default: + return XCTFail( + "Unexpected request action \(createdAction.request), connection index: \(connectionIndex)" + ) + } + + if connectionIndex == 0, + case .createConnection(let newConnectionID, let newConnectionEL) = createdAction.connection + { + (connectionID, connectionEL) = (newConnectionID, newConnectionEL) + connectionIDs.append(connectionID) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + } else if connectionIndex < 4, + case .scheduleTimeoutTimerAndCreateConnection(let timeoutID, let newConnectionID, let newConnectionEL) = + createdAction.connection + { + XCTAssertEqual(connectionID, timeoutID) + (connectionID, connectionEL) = (newConnectionID, newConnectionEL) + connectionIDs.append(connectionID) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + } else if connectionIndex == 4, case .scheduleTimeoutTimer = createdAction.connection { + // Expected, the loop will terminate now. + () + } else { + return XCTFail( + "Unexpected connection action: \(createdAction.connection) with index \(connectionIndex)" + ) + } + } + + XCTAssertEqual(connections.count, 5) + XCTAssertEqual(connections.parked, 4) + XCTAssertEqual(connectionIDs.count, 5) + + // Now we complete the first request. + try connections.finishExecution(connectionIDs[0]) + action = state.http1ConnectionReleased(connectionIDs[0]) + guard case .scheduleTimeoutTimer = action.connection else { + return XCTFail("Unexpected action: \(action.connection)") + } + try connections.parkConnection(connectionIDs[0]) + + XCTAssertEqual(connections.count, 5) + XCTAssertEqual(connections.parked, 5) + XCTAssertEqual(connectionIDs.count, 5) + } + + func testPrewarmingCreatesUpToTheMax() throws { + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 4 + ) + + var connections = MockConnectionPool() + + // Attempt to send one request. Complete the connection creation immediately, deferring the next connection creation, and then complete the + // request. + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + var action = state.executeRequest(request) + guard case .createConnection(var connectionID, var connectionEL) = action.connection else { + return XCTFail("Unexpected connection action") + } + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + var conn = try connections.succeedConnectionCreationHTTP1(connectionID) + var createdAction = state.newHTTP1ConnectionCreated(conn) + guard case .createConnection(var newConnectionID, var newConnectionEL) = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + try connections.execute(mockRequest, on: conn) + try connections.finishExecution(connectionID) + action = state.http1ConnectionReleased(connectionID) + + // Here the state machine has _again_ asked us to create a connection. This is because the pre-warming + // phase takes any opportunity to do that. + guard + case .scheduleTimeoutTimerAndCreateConnection(_, let veryDelayedConnectionID, let veryDelayedLoop) = action + .connection + else { + return XCTFail("Unexpected action: \(action.connection)") + } + try connections.parkConnection(connectionID) + + // At this stage we're gonna end up creating 3 connections. No outstanding requests are present, so + // we only need the pre-warmed set, which includes the one we already made. + // + // The first will ask for another connection + (connectionID, connectionEL) = (newConnectionID, newConnectionEL) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard + case .scheduleTimeoutTimerAndCreateConnection(_, let nextConnectionID, let nextConnectionEL) = createdAction + .connection + else { + return XCTFail("Unexpected connection action: \(createdAction.connection)") + } + (newConnectionID, newConnectionEL) = (nextConnectionID, nextConnectionEL) + + // The second one only asks for a timeout. + (connectionID, connectionEL) = (newConnectionID, newConnectionEL) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard case .scheduleTimeoutTimer = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + + // Now we should complete the delayed connection request. This will also only ask for a timer. + (connectionID, connectionEL) = (veryDelayedConnectionID, veryDelayedLoop) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard case .scheduleTimeoutTimer = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + + XCTAssertEqual(connections.count, 4) + XCTAssertEqual(connections.parked, 4) + + // Now we start sending requests. The first 4 requests will be accompanied by requests to create new connections, + // because as each connection goes out, the pre-warming creates another. We'll let them succeed. + for _ in 0..<4 { + let eventLoop = elg.next() + + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop) + let request = HTTPConnectionPool.Request(mockRequest) + let action = state.executeRequest(request) + + guard + case .createConnectionAndCancelTimeoutTimer( + let newConnectionID, + let newConnectionLoop, + let activatedConnectionID + ) = action.connection + else { + return XCTFail("Unexpected connection action: \(action)") + } + + guard case .executeRequest(_, let connection, _) = action.request else { + return XCTFail("Expected to execute a request next, but got: \(action.request)") + } + + try connections.activateConnection(activatedConnectionID) + try connections.execute(mockRequest, on: connection) + + // Now create the new connection. + XCTAssertNoThrow(try connections.createConnection(newConnectionID, on: newConnectionLoop)) + conn = try connections.succeedConnectionCreationHTTP1(newConnectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(newConnectionID) + } + + XCTAssertEqual(connections.count, 8) + XCTAssertEqual(connections.parked, 4) + + // The next 4 should _not_ ask to create new connections. We're at the cap, and prewarming can't exceed it. + for _ in 0..<4 { + let eventLoop = elg.next() + + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop) + let request = HTTPConnectionPool.Request(mockRequest) + let action = state.executeRequest(request) + + guard case .cancelTimeoutTimer(let activatedConnectionID) = action.connection else { + return XCTFail("Unexpected connection action: \(action)") + } + + guard case .executeRequest(_, let connection, _) = action.request else { + return XCTFail("Expected to execute a request next, but got: \(action.request)") + } + + try connections.activateConnection(activatedConnectionID) + try connections.execute(mockRequest, on: connection) + } + + XCTAssertEqual(connections.count, 8) + XCTAssertEqual(connections.parked, 0) + + while let connectionID = connections.randomActiveConnection() { + try connections.finishExecution(connectionID) + action = state.http1ConnectionReleased(connectionID) + + guard case .scheduleTimeoutTimer = action.connection else { + return XCTFail("Unexpected connection action: \(action.connection)") + } + } + + XCTAssertEqual(connections.count, 8) + XCTAssertEqual(connections.parked, 0) + } + + func testPrewarmingAffectsConnectionFailure() throws { + struct SomeError: Error {} + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 4 + ) + + var connections = MockConnectionPool() + + // Attempt to send one request. Complete the connection creation immediately, deferring the next connection creation, and then complete the + // request. + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + var action = state.executeRequest(request) + guard case .createConnection(var connectionID, var connectionEL) = action.connection else { + return XCTFail("Unexpected connection action") + } + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + var conn = try connections.succeedConnectionCreationHTTP1(connectionID) + var createdAction = state.newHTTP1ConnectionCreated(conn) + guard case .createConnection(var newConnectionID, var newConnectionEL) = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + try connections.execute(mockRequest, on: conn) + try connections.finishExecution(connectionID) + action = state.http1ConnectionReleased(connectionID) + + // Here the state machine has _again_ asked us to create a connection. This is because the pre-warming + // phase takes any opportunity to do that. + guard + case .scheduleTimeoutTimerAndCreateConnection(_, let veryDelayedConnectionID, let veryDelayedLoop) = action + .connection + else { + return XCTFail("Unexpected action: \(action.connection)") + } + try connections.parkConnection(connectionID) + + // At this stage we're gonna end up creating 3 connections. No outstanding requests are present, so + // we only need the pre-warmed set, which includes the one we already made. + // + // The first will ask for another connection + (connectionID, connectionEL) = (newConnectionID, newConnectionEL) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard + case .scheduleTimeoutTimerAndCreateConnection(_, let nextConnectionID, let nextConnectionEL) = createdAction + .connection + else { + return XCTFail("Unexpected connection action: \(createdAction.connection)") + } + (newConnectionID, newConnectionEL) = (nextConnectionID, nextConnectionEL) + + // The second one only asks for a timeout. + (connectionID, connectionEL) = (newConnectionID, newConnectionEL) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard case .scheduleTimeoutTimer = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + + // Now we should complete the delayed connection request. This will also only ask for a timer. + (connectionID, connectionEL) = (veryDelayedConnectionID, veryDelayedLoop) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard case .scheduleTimeoutTimer = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + + XCTAssertEqual(connections.count, 4) + XCTAssertEqual(connections.parked, 4) + + // Now, one of these connections idle-fails. + let parked = connections.randomParkedConnection()! + try connections.closeConnection(parked) + action = state.http1ConnectionClosed(parked.id) + + guard case .createConnection(var id, on: let loop) = action.connection else { + return XCTFail("Unexpected connection action: \(action.connection)") + } + + // A reasonable request. But it fails! + // + // Let's do this next bit a few times to convince ourselves it's a real problem. + for _ in 0..<8 { + // We're asked to schedule a backoff timer. + action = state.failedToCreateNewConnection(SomeError(), connectionID: id) + guard case .scheduleBackoffTimer(let backoffID, _, _) = action.connection else { + return XCTFail("Unexpected connection action: \(action.connection)") + } + XCTAssertEqual(backoffID, id) + + // Once it passes, ask what to do. We'll be asked, again, to create a connection. + action = state.connectionCreationBackoffDone(backoffID) + guard case .createConnection(let backedOffID, on: let backedOffLoop) = action.connection else { + return XCTFail("Unexpected connection action: \(action.connection)") + } + XCTAssertNotEqual(backedOffID, id) + XCTAssertIdentical(backedOffLoop, loop) + id = backedOffID + } + + // Finally it works. + XCTAssertNoThrow(try connections.createConnection(id, on: loop)) + conn = try connections.succeedConnectionCreationHTTP1(id) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(id) + + guard case .scheduleTimeoutTimer = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + } + + func testIdleConnectionTimeoutHandlingWithPrewarming() throws { + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 4 + ) + + var connections = MockConnectionPool() + + // Attempt to send one request. Complete the connection creation immediately, deferring the next connection creation, and then complete the + // request. + var mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + var request = HTTPConnectionPool.Request(mockRequest) + var action = state.executeRequest(request) + guard case .createConnection(var connectionID, var connectionEL) = action.connection else { + return XCTFail("Unexpected connection action") + } + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + var conn = try connections.succeedConnectionCreationHTTP1(connectionID) + var createdAction = state.newHTTP1ConnectionCreated(conn) + guard case .createConnection(var newConnectionID, var newConnectionEL) = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + try connections.execute(mockRequest, on: conn) + try connections.finishExecution(connectionID) + action = state.http1ConnectionReleased(connectionID) + + // Here the state machine has _again_ asked us to create a connection. This is because the pre-warming + // phase takes any opportunity to do that. + guard + case .scheduleTimeoutTimerAndCreateConnection(_, let veryDelayedConnectionID, let veryDelayedLoop) = action + .connection + else { + return XCTFail("Unexpected action: \(action.connection)") + } + try connections.parkConnection(connectionID) + + // At this stage we're gonna end up creating 3 connections. No outstanding requests are present, so + // we only need the pre-warmed set, which includes the one we already made. + // + // The first will ask for another connection + (connectionID, connectionEL) = (newConnectionID, newConnectionEL) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard + case .scheduleTimeoutTimerAndCreateConnection(_, let nextConnectionID, let nextConnectionEL) = createdAction + .connection + else { + return XCTFail("Unexpected connection action: \(createdAction.connection)") + } + (newConnectionID, newConnectionEL) = (nextConnectionID, nextConnectionEL) + + // The second one only asks for a timeout. + (connectionID, connectionEL) = (newConnectionID, newConnectionEL) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard case .scheduleTimeoutTimer = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + + // Now we should complete the delayed connection request. This will also only ask for a timer. + (connectionID, connectionEL) = (veryDelayedConnectionID, veryDelayedLoop) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard case .scheduleTimeoutTimer = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + + XCTAssertEqual(connections.count, 4) + XCTAssertEqual(connections.parked, 4) + + // Now, the idle timeout timer fires. We can do this a few times, it'll keep + // re-arming. + for _ in 0..<8 { + action = state.connectionIdleTimeout(connectionID, on: connectionEL) + guard case .scheduleTimeoutTimer = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + } + + // Let's force another connection to be created for a request. + mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + request = HTTPConnectionPool.Request(mockRequest) + action = state.executeRequest(request) + guard + case .createConnectionAndCancelTimeoutTimer(let extraConnectionID, let extraConnectionEL, _) = action + .connection + else { + return XCTFail("Unexpected connection action: \(action.connection)") + } + guard case .executeRequest(_, let requestConnection, _) = action.request else { + return XCTFail("Unexpected request action") + } + + XCTAssertNoThrow(try connections.createConnection(extraConnectionID, on: extraConnectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(extraConnectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + guard case .scheduleTimeoutTimer = createdAction.connection else { + return XCTFail("Unexpected connection action: \(createdAction.connection)") + } + try connections.activateConnection(requestConnection.id) + try connections.execute(mockRequest, on: requestConnection) + try connections.finishExecution(requestConnection.id) + try connections.parkConnection(requestConnection.id) + action = state.http1ConnectionReleased(requestConnection.id) + + // Back to idle. + guard case .scheduleTimeoutTimer = action.connection else { + return XCTFail("Unexpected action: \(action.connection)") + } + try connections.parkConnection(extraConnectionID) + + XCTAssertEqual(connections.count, 5) + XCTAssertEqual(connections.parked, 5) + + // This time when the idle timeout fires, we're actually asked to close the connection. + action = state.connectionIdleTimeout(connectionID, on: connectionEL) + guard case .closeConnection = action.connection else { + return XCTFail("Unexpected connection action: \(createdAction.connection)") + } + } + + func testPrewarmingForcesReCreationOfConnectionsWhenTheyHitMaxUses() throws { + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + // The scenario we want to hit can only happen when there is never a spare pre-warmed connection + // in the pool _and_ we can't create more. The easiest way to test this is to just + // create pre-warmed connections up to the pool limit, which they won't pass. + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: 1, + preWarmedHTTP1ConnectionCount: 8 + ) + + var connections = MockConnectionPool() + + // Attempt to send one request. Complete the connection creation immediately, deferring the next connection creation, but don't + // complete the request. + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + var action = state.executeRequest(request) + guard case .createConnection(var connectionID, var connectionEL) = action.connection else { + return XCTFail("Unexpected connection action") + } + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + var conn = try connections.succeedConnectionCreationHTTP1(connectionID) + var createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + guard case .createConnection(var newConnectionID, var newConnectionEL) = createdAction.connection else { + return XCTFail("Unexpected connection action") + } + guard case .executeRequest(_, let requestConn, _) = createdAction.request else { + return XCTFail("Unexpected request action: \(action.request)") + } + + // At this stage we're gonna end up creating 7 more connections. No outstanding requests are present, so + // we only need the pre-warmed set, which includes the one we already made. + // + // The first six will ask for another connection. + for _ in 0..<6 { + (connectionID, connectionEL) = (newConnectionID, newConnectionEL) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard + case .scheduleTimeoutTimerAndCreateConnection(_, let nextConnectionID, let nextConnectionEL) = + createdAction.connection + else { + return XCTFail("Unexpected connection action: \(createdAction.connection)") + } + (newConnectionID, newConnectionEL) = (nextConnectionID, nextConnectionEL) + } + + // The seventh one only asks for a timeout. + (connectionID, connectionEL) = (newConnectionID, newConnectionEL) + XCTAssertNoThrow(try connections.createConnection(connectionID, on: connectionEL)) + conn = try connections.succeedConnectionCreationHTTP1(connectionID) + createdAction = state.newHTTP1ConnectionCreated(conn) + try connections.parkConnection(connectionID) + + guard case .scheduleTimeoutTimer = createdAction.connection else { + return XCTFail("Unexpected connection action: \(createdAction.connection)") + } + + XCTAssertEqual(connections.count, 8) + XCTAssertEqual(connections.parked, 8) + + // Now we're gonna actually complete that request from earlier. + try connections.activateConnection(requestConn.id) + try connections.execute(mockRequest, on: requestConn) + try connections.finishExecution(requestConn.id) + action = state.http1ConnectionReleased(requestConn.id) + + // Here the state machine has asked us to close the connection and create a new one. That's because we've hit the + // max usages limit. + guard case .closeConnectionAndCreateConnection(let toClose, _, _) = action.connection else { + return XCTFail("Unexpected action: \(action.connection)") + } + try connections.closeConnection(toClose) + + // We won't bother doing it though, it's enough that it asked. + } + + func testFailConnectionRacesAgainstConnectionCreationFailed() { + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 2, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) + + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + + let executeAction = state.executeRequest(request) + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), executeAction.request) + + // 1. connection attempt + guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { + return XCTFail("Unexpected connection action: \(executeAction.connection)") + } + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + + // 2. connection fails – first with closed callback + + XCTAssertEqual(state.http1ConnectionClosed(connectionID), .none) + + // 3. connection fails – with make connection callback + + let action = state.failedToCreateNewConnection( + IOError(errnoCode: -1, reason: "Test failure"), + connectionID: connectionID + ) + XCTAssertEqual(action.request, .none) + guard case .scheduleBackoffTimer(connectionID, _, on: let backoffTimerEL) = action.connection else { + XCTFail("Unexpected connection action: \(action.connection)") + return + } + XCTAssertIdentical(connectionEL, backoffTimerEL) + + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest+XCTest.swift deleted file mode 100644 index 95cade669..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest+XCTest.swift +++ /dev/null @@ -1,49 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+HTTP2ConnectionsTest+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_HTTP2ConnectionsTests { - static var allTests: [(String, (HTTPConnectionPool_HTTP2ConnectionsTests) -> () throws -> Void)] { - return [ - ("testCreatingConnections", testCreatingConnections), - ("testCreatingConnectionAndFailing", testCreatingConnectionAndFailing), - ("testFailConnectionRace", testFailConnectionRace), - ("testLeaseConnectionOfPreferredButUnavailableEL", testLeaseConnectionOfPreferredButUnavailableEL), - ("testLeaseConnectionOnRequiredButUnavailableEL", testLeaseConnectionOnRequiredButUnavailableEL), - ("testCloseConnectionIfIdle", testCloseConnectionIfIdle), - ("testCloseConnectionIfIdleButLeasedRaceCondition", testCloseConnectionIfIdleButLeasedRaceCondition), - ("testCloseConnectionIfIdleButClosedRaceCondition", testCloseConnectionIfIdleButClosedRaceCondition), - ("testCloseConnectionIfIdleRace", testCloseConnectionIfIdleRace), - ("testShutdown", testShutdown), - ("testLeasingAllConnections", testLeasingAllConnections), - ("testGoAway", testGoAway), - ("testNewMaxConcurrentStreamsSetting", testNewMaxConcurrentStreamsSetting), - ("testEventsAfterConnectionIsClosed", testEventsAfterConnectionIsClosed), - ("testLeaseOnPreferredEventLoopWithoutAnyAvailable", testLeaseOnPreferredEventLoopWithoutAnyAvailable), - ("testMigrationFromHTTP1", testMigrationFromHTTP1), - ("testMigrationToHTTP1", testMigrationToHTTP1), - ("testMigrationFromHTTP1WithPendingRequestsWithRequiredEventLoop", testMigrationFromHTTP1WithPendingRequestsWithRequiredEventLoop), - ("testMigrationFromHTTP1WithAlreadyEstablishedHTTP2Connection", testMigrationFromHTTP1WithAlreadyEstablishedHTTP2Connection), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift index 9e9ca1df6..dbfe90ff9 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift @@ -12,15 +12,16 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testCreatingConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() @@ -32,7 +33,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el1)) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) XCTAssertEqual(conn1CreatedContext.isIdle, true) XCTAssert(conn1CreatedContext.eventLoop === el1) @@ -46,7 +50,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let conn2ID = connections.createNewConnection(on: el2) XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el2)) let conn2: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn2ID, eventLoop: el2) - let (conn2Index, conn2CreatedContext) = connections.newHTTP2ConnectionEstablished(conn2, maxConcurrentStreams: 100) + let (conn2Index, conn2CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn2, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) XCTAssertTrue(conn1CreatedContext.isIdle) XCTAssert(conn2CreatedContext.eventLoop === el2) @@ -59,7 +66,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testCreatingConnectionAndFailing() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() @@ -83,7 +90,9 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssert(conn1FailContext.eventLoop === el1) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el1)) - let (replaceConn1ID, replaceConn1EL) = connections.createNewConnectionByReplacingClosedConnection(at: conn1FailIndex) + let (replaceConn1ID, replaceConn1EL) = connections.createNewConnectionByReplacingClosedConnection( + at: conn1FailIndex + ) XCTAssert(replaceConn1EL === el1) XCTAssertEqual(replaceConn1ID, 1) XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) @@ -108,7 +117,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el1 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) // connection is idle let conn1ID = connections.createNewConnection(on: el1) @@ -130,7 +139,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el4 = elg.next() let el5 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) for el in [el1, el2, el3, el4] { XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el)) @@ -155,7 +164,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el4 = elg.next() let el5 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) for el in [el1, el2, el3, el4] { XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el)) @@ -177,7 +186,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el1 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) // connection is idle let conn1ID = connections.createNewConnection(on: el1) @@ -201,7 +210,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el1 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) // connection is idle let conn1ID = connections.createNewConnection(on: el1) @@ -224,7 +233,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el1 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) // connection is idle let conn1ID = connections.createNewConnection(on: el1) @@ -241,7 +250,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el1 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) // connection is idle let conn1ID = connections.createNewConnection(on: el1) @@ -268,7 +277,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let el5 = elg.next() let el6 = elg.next() - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) for el in [el1, el2, el3, el4] { XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el)) @@ -322,6 +331,8 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertEqual(connections.closeConnection(at: releaseIndex), leasedConn) XCTAssertFalse(connections.isEmpty) + let backoffEL = connections.backoffNextConnectionAttempt(startingID) + XCTAssertIdentical(el6, backoffEL) guard let (failIndex, _) = connections.failConnection(startingID) else { return XCTFail("Expected that the connection is remembered") } @@ -331,18 +342,24 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testLeasingAllConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 100) XCTAssertEqual(leasedConn1, conn1) XCTAssertEqual(leasdConnContext1.wasIdle, true) - XCTAssertNil(connections.leaseStream(onRequired: el1), "should not be able to lease stream because they are all already leased") + XCTAssertNil( + connections.leaseStream(onRequired: el1), + "should not be able to lease stream because they are all already leased" + ) let (_, releaseContext) = connections.releaseStream(conn1ID) XCTAssertFalse(releaseContext.isIdle) @@ -354,17 +371,23 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertEqual(leasedConn, conn1) XCTAssertEqual(leaseContext.wasIdle, false) - XCTAssertNil(connections.leaseStream(onRequired: el1), "should not be able to lease stream because they are all already leased") + XCTAssertNil( + connections.leaseStream(onRequired: el1), + "should not be able to lease stream because they are all already leased" + ) } func testGoAway() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 10) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 10 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 10) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 2) @@ -386,7 +409,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { ) ) - XCTAssertNil(connections.leaseStream(onRequired: el1), "we should not be able to lease a stream because the connection is draining") + XCTAssertNil( + connections.leaseStream(onRequired: el1), + "we should not be able to lease a stream because the connection is draining" + ) // a server can potentially send more than one connection go away and we should not crash XCTAssertTrue(connections.goAwayReceived(conn1ID)?.eventLoop === el1) @@ -440,12 +466,15 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testNewMaxConcurrentStreamsSetting() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 1) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 1 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 1) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 1) @@ -454,7 +483,8 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertNil(connections.leaseStream(onRequired: el1), "all streams are in use") - guard let (_, newSettingsContext1) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 2) else { + guard let (_, newSettingsContext1) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 2) + else { return XCTFail("Expected to get a new settings context") } XCTAssertEqual(newSettingsContext1.availableStreams, 1) @@ -467,7 +497,8 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertEqual(leasedConn2, conn1) XCTAssertEqual(leaseContext2.wasIdle, false) - guard let (_, newSettingsContext2) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 1) else { + guard let (_, newSettingsContext2) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 1) + else { return XCTFail("Expected to get a new settings context") } XCTAssertEqual(newSettingsContext2.availableStreams, 0) @@ -495,12 +526,15 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testEventsAfterConnectionIsClosed() { let elg = EmbeddedEventLoopGroup(loops: 2) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 1) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 1 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 1) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 1) @@ -530,12 +564,15 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testLeaseOnPreferredEventLoopWithoutAnyAvailable() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 1) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 1 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 1) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 1) XCTAssertEqual(leasedConn1, conn1) @@ -546,7 +583,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testMigrationFromHTTP1() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP2Connections(generator: .init()) + var connections = HTTPConnectionPool.HTTP2Connections(generator: .init(), maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() let conn1ID: HTTPConnectionPool.Connection.ID = 1 @@ -556,9 +593,11 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { starting: [(conn1ID, el1)], backingOff: [(conn2ID, el2)] ) - XCTAssertTrue(connections.createConnectionsAfterMigrationIfNeeded( - requiredEventLoopsOfPendingRequests: [el1, el2] - ).isEmpty) + XCTAssertTrue( + connections.createConnectionsAfterMigrationIfNeeded( + requiredEventLoopsOfPendingRequests: [el1, el2] + ).isEmpty + ) XCTAssertEqual( connections.stats, @@ -574,7 +613,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { ) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 2) @@ -598,7 +640,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testMigrationToHTTP1() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP2Connections(generator: generator) + var connections = HTTPConnectionPool.HTTP2Connections(generator: generator, maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() let el3 = elg.next() @@ -615,7 +657,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { ) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 2) @@ -663,7 +708,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testMigrationFromHTTP1WithPendingRequestsWithRequiredEventLoop() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP2Connections(generator: generator) + var connections = HTTPConnectionPool.HTTP2Connections(generator: generator, maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() let el3 = elg.next() @@ -696,7 +741,7 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testMigrationFromHTTP1WithAlreadyEstablishedHTTP2Connection() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP2Connections(generator: generator) + var connections = HTTPConnectionPool.HTTP2Connections(generator: generator, maximumConnectionUses: nil) let el1 = elg.next() let el2 = elg.next() let el3 = elg.next() @@ -714,9 +759,12 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { backingOff: [(conn3ID, el3)] ) - XCTAssertTrue(connections.createConnectionsAfterMigrationIfNeeded( - requiredEventLoopsOfPendingRequests: [el1, el2, el3] - ).isEmpty, "we still have an active connection for el1 and should not create a new one") + XCTAssertTrue( + connections.createConnectionsAfterMigrationIfNeeded( + requiredEventLoopsOfPendingRequests: [el1, el2, el3] + ).isEmpty, + "we still have an active connection for el1 and should not create a new one" + ) guard let (leasedConn, _) = connections.leaseStream(onRequired: el1) else { return XCTFail("could not lease stream on el1") diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests+XCTest.swift deleted file mode 100644 index 9dca0c934..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests+XCTest.swift +++ /dev/null @@ -1,49 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+HTTP2StateMachineTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_HTTP2StateMachineTests { - static var allTests: [(String, (HTTPConnectionPool_HTTP2StateMachineTests) -> () throws -> Void)] { - return [ - ("testCreatingOfConnection", testCreatingOfConnection), - ("testConnectionFailureBackoff", testConnectionFailureBackoff), - ("testCancelRequestWorks", testCancelRequestWorks), - ("testExecuteOnShuttingDownPool", testExecuteOnShuttingDownPool), - ("testHTTP1ToHTTP2MigrationAndShutdownIfFirstConnectionIsHTTP1", testHTTP1ToHTTP2MigrationAndShutdownIfFirstConnectionIsHTTP1), - ("testSchedulingAndCancelingOfIdleTimeout", testSchedulingAndCancelingOfIdleTimeout), - ("testConnectionTimeout", testConnectionTimeout), - ("testConnectionEstablishmentFailure", testConnectionEstablishmentFailure), - ("testGoAwayOnIdleConnection", testGoAwayOnIdleConnection), - ("testGoAwayWithLeasedStream", testGoAwayWithLeasedStream), - ("testGoAwayWithPendingRequestsStartsNewConnection", testGoAwayWithPendingRequestsStartsNewConnection), - ("testMigrationFromHTTP1ToHTTP2", testMigrationFromHTTP1ToHTTP2), - ("testMigrationFromHTTP1ToHTTP2WhileShuttingDown", testMigrationFromHTTP1ToHTTP2WhileShuttingDown), - ("testMigrationFromHTTP1ToHTTP2WithAlreadyStartedHTTP1Connections", testMigrationFromHTTP1ToHTTP2WithAlreadyStartedHTTP1Connections), - ("testHTTP2toHTTP1Migration", testHTTP2toHTTP1Migration), - ("testHTTP2toHTTP1MigrationDuringShutdown", testHTTP2toHTTP1MigrationDuringShutdown), - ("testConnectionIsImmediatelyCreatedAfterBackoffTimerFires", testConnectionIsImmediatelyCreatedAfterBackoffTimerFires), - ("testMaxConcurrentStreamsIsRespected", testMaxConcurrentStreamsIsRespected), - ("testEventsAfterConnectionIsClosed", testEventsAfterConnectionIsClosed), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift index 825ffc9b3..8fead4f4d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + private typealias Action = HTTPConnectionPool.StateMachine.Action private typealias ConnectionAction = HTTPConnectionPool.StateMachine.ConnectionAction private typealias RequestAction = HTTPConnectionPool.StateMachine.RequestAction @@ -29,10 +30,15 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el1 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: .init(), lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: .init(), + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) /// first request should create a new connection - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -48,7 +54,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// subsequent requests should not create a connection for _ in 0..<9 { - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -99,7 +105,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// 4 streams are available and therefore request should be executed immediately for _ in 0..<4 { - let mockRequest = MockHTTPRequest(eventLoop: el1, requiresEventLoopForChannel: true) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1, requiresEventLoopForChannel: true) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -122,14 +128,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// shutdown should only close one connection let shutdownAction = state.shutdown() XCTAssertEqual(shutdownAction.request, .none) - XCTAssertEqual(shutdownAction.connection, .cleanupConnections( - .init( - close: [conn], - cancel: [], - connectBackoff: [] - ), - isShutdown: .yes(unclean: false) - )) + XCTAssertEqual( + shutdownAction.connection, + .cleanupConnections( + .init( + close: [conn], + cancel: [], + connectBackoff: [] + ), + isShutdown: .yes(unclean: false) + ) + ) } func testConnectionFailureBackoff() { @@ -138,10 +147,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: .init(), - lifecycleState: .running + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -151,9 +162,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { return XCTFail("Unexpected connection action: \(action.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux - let failedConnect1 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: connectionID) + let failedConnect1 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: connectionID + ) XCTAssertEqual(failedConnect1.request, .none) guard case .scheduleBackoffTimer(connectionID, let backoffTimeAmount1, _) = failedConnect1.connection else { return XCTFail("Unexpected connection action: \(failedConnect1.connection)") @@ -166,9 +180,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { return XCTFail("Unexpected connection action: \(backoffDoneAction.connection)") } XCTAssertGreaterThan(newConnectionID, connectionID) - XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux - let failedConnect2 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: newConnectionID) + let failedConnect2 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: newConnectionID + ) XCTAssertEqual(failedConnect2.request, .none) guard case .scheduleBackoffTimer(newConnectionID, let backoffTimeAmount2, _) = failedConnect2.connection else { return XCTFail("Unexpected connection action: \(failedConnect2.connection)") @@ -181,7 +198,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .failRequest(let requestToFail, let requestError, cancelTimeout: false) = failRequest.request else { return XCTFail("Unexpected request action: \(action.request)") } - XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + // XCTAssertIdentical not available on Linux + XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) XCTAssertEqual(requestError as? HTTPClientError, .connectTimeout) XCTAssertEqual(failRequest.connection, .none) @@ -189,16 +207,91 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertEqual(state.connectionCreationBackoffDone(newConnectionID), .none) } + func testConnectionFailureWhileShuttingDown() { + struct SomeError: Error, Equatable {} + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: .init(), + retryConnectionEstablishment: false, + lifecycleState: .running, + maximumConnectionUses: nil + ) + + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + + let action = state.executeRequest(request) + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + + // 1. connection attempt + guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { + return XCTFail("Unexpected connection action: \(action.connection)") + } + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + + // 2. initialise shutdown + let shutdownAction = state.shutdown() + XCTAssertEqual(shutdownAction.connection, .cleanupConnections(.init(), isShutdown: .no)) + guard case .failRequestsAndCancelTimeouts(let requestsToFail, let requestError) = shutdownAction.request else { + return XCTFail("Unexpected request action: \(action.request)") + } + XCTAssertEqualTypeAndValue(requestError, HTTPClientError.cancelled) + XCTAssertEqualTypeAndValue(requestsToFail, [request]) + + // 3. connection attempt fails + let failedConnectAction = state.failedToCreateNewConnection(SomeError(), connectionID: connectionID) + XCTAssertEqual(failedConnectAction.request, .none) + XCTAssertEqual(failedConnectAction.connection, .cleanupConnections(.init(), isShutdown: .yes(unclean: true))) + } + + func testConnectionFailureWithoutRetry() { + struct SomeError: Error, Equatable {} + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: .init(), + retryConnectionEstablishment: false, + lifecycleState: .running, + maximumConnectionUses: nil + ) + + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + + let action = state.executeRequest(request) + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), action.request) + + // 1. connection attempt + guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { + return XCTFail("Unexpected connection action: \(action.connection)") + } + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + + let failedConnectAction = state.failedToCreateNewConnection(SomeError(), connectionID: connectionID) + XCTAssertEqual(failedConnectAction.connection, .none) + guard case .failRequestsAndCancelTimeouts(let requestsToFail, let requestError) = failedConnectAction.request + else { + return XCTFail("Unexpected request action: \(action.request)") + } + XCTAssertEqualTypeAndValue(requestError, SomeError()) + XCTAssertEqualTypeAndValue(requestsToFail, [request]) + } + func testCancelRequestWorks() { let elg = EmbeddedEventLoopGroup(loops: 4) defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: .init(), - lifecycleState: .running + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -208,11 +301,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. cancel request let cancelAction = state.cancelRequest(request.id) - XCTAssertEqual(cancelAction.request, .cancelRequestTimeout(request.id)) + XCTAssertEqual(cancelAction.request, .failRequest(request, HTTPClientError.cancelled, cancelTimeout: true)) XCTAssertEqual(cancelAction.connection, .none) // 3. request timeout triggers to late @@ -233,10 +326,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: .init(), - lifecycleState: .running + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil ) - let mockRequest = MockHTTPRequest(eventLoop: elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -246,15 +341,18 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. connection succeeds - let connection: HTTPConnectionPool.Connection = .__testOnly_connection(id: connectionID, eventLoop: connectionEL) + let connection: HTTPConnectionPool.Connection = .__testOnly_connection( + id: connectionID, + eventLoop: connectionEL + ) let connectedAction = state.newHTTP2ConnectionEstablished(connection, maxConcurrentStreams: 100) guard case .executeRequestsAndCancelTimeouts([request], connection) = connectedAction.request else { return XCTFail("Unexpected request action: \(connectedAction.request)") } - XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux XCTAssertEqual(connectedAction.connection, .none) // 3. shutdown @@ -270,11 +368,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertEqual(cleanupContext.connectBackoff, []) // 4. execute another request - let finalMockRequest = MockHTTPRequest(eventLoop: elg.next()) + let finalMockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) let finalRequest = HTTPConnectionPool.Request(finalMockRequest) let failAction = state.executeRequest(finalRequest) XCTAssertEqual(failAction.connection, .none) - XCTAssertEqual(failAction.request, .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false)) + XCTAssertEqual( + failAction.request, + .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false) + ) // 5. close open connection let closeAction = state.http2ConnectionClosed(connectionID) @@ -287,11 +388,18 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el1 = elg.next() let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1State = HTTPConnectionPool.HTTP1StateMachine(idGenerator: idGenerator, maximumConcurrentConnections: 8, lifecycleState: .running) + var http1State = HTTPConnectionPool.HTTP1StateMachine( + idGenerator: idGenerator, + maximumConcurrentConnections: 8, + retryConnectionEstablishment: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0, + lifecycleState: .running + ) - let mockRequest1 = MockHTTPRequest(eventLoop: el1) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest1) - let mockRequest2 = MockHTTPRequest(eventLoop: el1) + let mockRequest2 = MockHTTPScheduableRequest(eventLoop: el1) let request2 = HTTPConnectionPool.Request(mockRequest2) let executeAction1 = http1State.executeRequest(request1) @@ -313,7 +421,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // second connection is a HTTP2 connection and we need to migrate let conn2: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn2ID, eventLoop: el1) - var http2State = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var http2State = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let http2ConnectAction = http2State.migrateFromHTTP1( http1Connections: http1State.connections, @@ -322,7 +435,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { newHTTP2Connection: conn2, maxConcurrentStreams: 100 ) - XCTAssertEqual(http2ConnectAction.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual( + http2ConnectAction.connection, + .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil) + ) guard case .executeRequestsAndCancelTimeouts([request2], conn2) = http2ConnectAction.request else { return XCTFail("Unexpected request action \(http2ConnectAction.request)") } @@ -334,11 +450,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let shutdownAction = http2State.shutdown() XCTAssertEqual(shutdownAction.request, .none) - XCTAssertEqual(shutdownAction.connection, .cleanupConnections(.init( - close: [conn2], - cancel: [], - connectBackoff: [] - ), isShutdown: .no)) + XCTAssertEqual( + shutdownAction.connection, + .cleanupConnections( + .init( + close: [conn2], + cancel: [], + connectBackoff: [] + ), + isShutdown: .no + ) + ) let releaseAction = http2State.http1ConnectionReleased(conn1ID) XCTAssertEqual(releaseAction.request, .none) @@ -351,22 +473,40 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) - let connectAction = state.migrateFromHTTP1(http1Connections: http1Conns, requests: .init(), newHTTP2Connection: conn1, maxConcurrentStreams: 100) + let connectAction = state.migrateFromHTTP1( + http1Connections: http1Conns, + requests: .init(), + newHTTP2Connection: conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // execute request on idle connection - let mockRequest1 = MockHTTPRequest(eventLoop: el1) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest1) let request1Action = state.executeRequest(request1) XCTAssertEqual(request1Action.request, .executeRequest(request1, conn1, cancelTimeout: false)) @@ -378,7 +518,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertEqual(closeStream1Action.connection, .scheduleTimeoutTimer(conn1ID, on: el1)) // execute request on idle connection with required event loop - let mockRequest2 = MockHTTPRequest(eventLoop: el1, requiresEventLoopForChannel: true) + let mockRequest2 = MockHTTPScheduableRequest(eventLoop: el1, requiresEventLoopForChannel: true) let request2 = HTTPConnectionPool.Request(mockRequest2) let request2Action = state.executeRequest(request2) XCTAssertEqual(request2Action.request, .executeRequest(request2, conn1, cancelTimeout: false)) @@ -396,18 +536,36 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) - let connectAction = state.migrateFromHTTP1(http1Connections: http1Conns, requests: .init(), newHTTP2Connection: conn1, maxConcurrentStreams: 100) + let connectAction = state.migrateFromHTTP1( + http1Connections: http1Conns, + requests: .init(), + newHTTP2Connection: conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // let the connection timeout let timeoutAction = state.connectionIdleTimeout(conn1ID) @@ -424,20 +582,38 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) - let connectAction = state.migrateFromHTTP1(http1Connections: http1Conns, requests: .init(), newHTTP2Connection: conn1, maxConcurrentStreams: 100) + let connectAction = state.migrateFromHTTP1( + http1Connections: http1Conns, + requests: .init(), + newHTTP2Connection: conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // create new http2 connection - let mockRequest1 = MockHTTPRequest(eventLoop: el2, requiresEventLoopForChannel: true) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el2, requiresEventLoopForChannel: true) let request1 = HTTPConnectionPool.Request(mockRequest1) let executeAction = state.executeRequest(request1) XCTAssertEqual(executeAction.request, .scheduleRequestTimeout(for: request1, on: el2)) @@ -459,9 +635,19 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) @@ -472,11 +658,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { maxConcurrentStreams: 100 ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) let goAwayAction = state.http2ConnectionGoAwayReceived(conn1ID) XCTAssertEqual(goAwayAction.request, .none) @@ -489,9 +678,19 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) let connectAction = state.migrateFromHTTP1( @@ -501,14 +700,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { maxConcurrentStreams: 100 ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // execute request on idle connection - let mockRequest1 = MockHTTPRequest(eventLoop: el1) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest1) let request1Action = state.executeRequest(request1) XCTAssertEqual(request1Action.request, .executeRequest(request1, conn1, cancelTimeout: false)) @@ -530,9 +732,19 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) let conn1ID = http1Conns.createNewConnection(on: el1) - var state = HTTPConnectionPool.HTTP2StateMachine(idGenerator: idGenerator, lifecycleState: .running) + var state = HTTPConnectionPool.HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: true, + lifecycleState: .running, + maximumConnectionUses: nil + ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) let connectAction1 = state.migrateFromHTTP1( @@ -542,21 +754,24 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { maxConcurrentStreams: 1 ) XCTAssertEqual(connectAction1.request, .none) - XCTAssertEqual(connectAction1.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction1.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // execute request - let mockRequest1 = MockHTTPRequest(eventLoop: el1) + let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest1) let request1Action = state.executeRequest(request1) XCTAssertEqual(request1Action.request, .executeRequest(request1, conn1, cancelTimeout: false)) XCTAssertEqual(request1Action.connection, .cancelTimeoutTimer(conn1ID)) // queue request - let mockRequest2 = MockHTTPRequest(eventLoop: el1) + let mockRequest2 = MockHTTPScheduableRequest(eventLoop: el1) let request2 = HTTPConnectionPool.Request(mockRequest2) let request2Action = state.executeRequest(request2) XCTAssertEqual(request2Action.request, .scheduleRequestTimeout(for: request2, on: el1)) @@ -592,12 +807,19 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el1 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) /// first 8 request should create a new connection var connectionIDs: [HTTPConnectionPool.Connection.ID] = [] for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .createConnection(let connID, let eventLoop) = action.connection else { @@ -616,7 +838,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// after we reached the `maximumConcurrentHTTP1Connections`, we will not create new connections for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) XCTAssertEqual(action.connection, .none) @@ -640,11 +862,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.execute(request.__testOnly_wrapped_request(), on: conn1)) } - XCTAssertEqual(migrationAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: nil - )) + XCTAssertEqual( + migrationAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: nil + ) + ) /// remaining connections should be closed immediately without executing any request for connID in connectionIDs.dropFirst() { @@ -678,10 +903,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el1 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) /// create a new connection - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .createConnection(let conn1ID, let eventLoop) = action.connection else { @@ -720,12 +952,19 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el1 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) /// first 8 request should create a new connection var connectionIDs: [HTTPConnectionPool.Connection.ID] = [] for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .createConnection(let connID, let eventLoop) = action.connection else { @@ -740,7 +979,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// after we reached the `maximumConcurrentHTTP1Connections`, we will not create new connections for _ in 0..<8 { - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) XCTAssertEqual(action.connection, .none) @@ -791,7 +1030,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .executeRequestsAndCancelTimeouts(let requests, let conn) = migrationAction.request else { return XCTFail("unexpected request action \(migrationAction.request)") } - XCTAssertEqual(migrationAction.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual( + migrationAction.connection, + .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil) + ) XCTAssertEqual(conn, http2Conn) XCTAssertEqual(requests.count, 10) @@ -855,10 +1097,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el2 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) // create http2 connection - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest) let action1 = state.executeRequest(request1) guard case .createConnection(let http2ConnID, let http2EventLoop) = action1.connection else { @@ -870,11 +1119,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try queuer.queue(mockRequest, id: request1.id)) let http2Conn: HTTPConnectionPool.Connection = .__testOnly_connection(id: http2ConnID, eventLoop: el1) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP2(http2ConnID, maxConcurrentStreams: 10)) - let migrationAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) - guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = migrationAction1.request else { - return XCTFail("unexpected request action \(migrationAction1.request)") + let executeAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) + guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = executeAction1.request else { + return XCTFail("unexpected request action \(executeAction1.request)") } - XCTAssertEqual(migrationAction1.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual(requests.count, 1) for request in requests { XCTAssertNoThrow(try queuer.get(request.id, request: request.__testOnly_wrapped_request())) @@ -882,14 +1131,20 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } // a request with new required event loop should create a new connection - let mockRequestWithRequiredEventLoop = MockHTTPRequest(eventLoop: el2, requiresEventLoopForChannel: true) + let mockRequestWithRequiredEventLoop = MockHTTPScheduableRequest( + eventLoop: el2, + requiresEventLoopForChannel: true + ) let requestWithRequiredEventLoop = HTTPConnectionPool.Request(mockRequestWithRequiredEventLoop) let action2 = state.executeRequest(requestWithRequiredEventLoop) guard case .createConnection(let http1ConnId, let http1EventLoop) = action2.connection else { return XCTFail("Unexpected connection action \(action2.connection)") } XCTAssertTrue(http1EventLoop === el2) - XCTAssertEqual(action2.request, .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop)) + XCTAssertEqual( + action2.request, + .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop) + ) XCTAssertNoThrow(try connections.createConnection(http1ConnId, on: el2)) XCTAssertNoThrow(try queuer.queue(mockRequestWithRequiredEventLoop, id: requestWithRequiredEventLoop.id)) @@ -900,7 +1155,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .executeRequest(let request2, http1Conn, cancelTimeout: true) = migrationAction2.request else { return XCTFail("unexpected request action \(migrationAction2.request)") } - guard case .migration(let createConnections, closeConnections: [], scheduleTimeout: nil) = migrationAction2.connection else { + guard + case .migration(let createConnections, closeConnections: [], scheduleTimeout: nil) = migrationAction2 + .connection + else { return XCTFail("unexpected connection action \(migrationAction2.connection)") } XCTAssertEqual(createConnections.map { $0.1.id }, [el2.id]) @@ -921,10 +1179,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el2 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) // create http2 connection - let mockRequest = MockHTTPRequest(eventLoop: el1) + let mockRequest = MockHTTPScheduableRequest(eventLoop: el1) let request1 = HTTPConnectionPool.Request(mockRequest) let action1 = state.executeRequest(request1) guard case .createConnection(let http2ConnID, let http2EventLoop) = action1.connection else { @@ -936,11 +1201,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try queuer.queue(mockRequest, id: request1.id)) let http2Conn: HTTPConnectionPool.Connection = .__testOnly_connection(id: http2ConnID, eventLoop: el1) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP2(http2ConnID, maxConcurrentStreams: 10)) - let migrationAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) - guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = migrationAction1.request else { - return XCTFail("unexpected request action \(migrationAction1.request)") + let executeAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) + guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = executeAction1.request else { + return XCTFail("unexpected request action \(executeAction1.request)") } - XCTAssertEqual(migrationAction1.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual(requests.count, 1) for request in requests { XCTAssertNoThrow(try queuer.get(request.id, request: request.__testOnly_wrapped_request())) @@ -948,14 +1213,20 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } // a request with new required event loop should create a new connection - let mockRequestWithRequiredEventLoop = MockHTTPRequest(eventLoop: el2, requiresEventLoopForChannel: true) + let mockRequestWithRequiredEventLoop = MockHTTPScheduableRequest( + eventLoop: el2, + requiresEventLoopForChannel: true + ) let requestWithRequiredEventLoop = HTTPConnectionPool.Request(mockRequestWithRequiredEventLoop) let action2 = state.executeRequest(requestWithRequiredEventLoop) guard case .createConnection(let http1ConnId, let http1EventLoop) = action2.connection else { return XCTFail("Unexpected connection action \(action2.connection)") } XCTAssertTrue(http1EventLoop === el2) - XCTAssertEqual(action2.request, .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop)) + XCTAssertEqual( + action2.request, + .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop) + ) XCTAssertNoThrow(try connections.createConnection(http1ConnId, on: el2)) XCTAssertNoThrow(try queuer.queue(mockRequestWithRequiredEventLoop, id: requestWithRequiredEventLoop.id)) @@ -971,13 +1242,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } XCTAssertTrue(queuer.isEmpty) - // if we established a new http/1 connection we should migrate back to http/1, + // if we established a new http/1 connection we should migrate to http/1, // close the connection and shutdown the pool let http1Conn: HTTPConnectionPool.Connection = .__testOnly_connection(id: http1ConnId, eventLoop: el2) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP1(http1ConnId)) let migrationAction2 = state.newHTTP1ConnectionCreated(http1Conn) XCTAssertEqual(migrationAction2.request, .none) - XCTAssertEqual(migrationAction2.connection, .migration(createConnections: [], closeConnections: [http1Conn], scheduleTimeout: nil)) + XCTAssertEqual( + migrationAction2.connection, + .migration(createConnections: [], closeConnections: [http1Conn], scheduleTimeout: nil) + ) // in http/1 state, we should close idle http2 connections XCTAssertNoThrow(try connections.finishExecution(http2Conn.id)) @@ -993,11 +1267,18 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let el2 = elg.next() var connections = MockConnectionPool() var queuer = MockRequestQueuer() - var state = HTTPConnectionPool.StateMachine(idGenerator: .init(), maximumConcurrentHTTP1Connections: 8) + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) var connectionIDs: [HTTPConnectionPool.Connection.ID] = [] - for el in [el1, el2, el2] { - let mockRequest = MockHTTPRequest(eventLoop: el, requiresEventLoopForChannel: true) + for el in [el1, el2] { + let mockRequest = MockHTTPScheduableRequest(eventLoop: el, requiresEventLoopForChannel: true) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) guard case .createConnection(let connID, let eventLoop) = action.connection else { @@ -1010,7 +1291,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try queuer.queue(mockRequest, id: request.id)) } - // fail the two connections for el2 + // fail the connection for el2 for connectionID in connectionIDs.dropFirst() { struct SomeError: Error {} XCTAssertNoThrow(try connections.failConnectionCreation(connectionID)) @@ -1023,16 +1304,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } let http2ConnID1 = connectionIDs[0] let http2ConnID2 = connectionIDs[1] - let http2ConnID3 = connectionIDs[2] // let the first connection on el1 succeed as a http2 connection let http2Conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: http2ConnID1, eventLoop: el1) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP2(http2ConnID1, maxConcurrentStreams: 10)) - let migrationAction1 = state.newHTTP2ConnectionCreated(http2Conn1, maxConcurrentStreams: 10) - guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn1) = migrationAction1.request else { - return XCTFail("unexpected request action \(migrationAction1.request)") + let connectionAction = state.newHTTP2ConnectionCreated(http2Conn1, maxConcurrentStreams: 10) + guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn1) = connectionAction.request else { + return XCTFail("unexpected request action \(connectionAction.request)") } - XCTAssertEqual(migrationAction1.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) XCTAssertEqual(requests.count, 1) for request in requests { XCTAssertNoThrow(try queuer.get(request.id, request: request.__testOnly_wrapped_request())) @@ -1051,14 +1330,6 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } XCTAssertTrue(eventLoop2 === el2) XCTAssertNoThrow(try connections.createConnection(newHttp2ConnID2, on: el2)) - - // we now have a starting connection for el2 and another one backing off - - // if the backoff timer fires now for a connection on el2, we should *not* start a new connection - XCTAssertNoThrow(try connections.connectionBackoffTimerDone(http2ConnID3)) - let action3 = state.connectionCreationBackoffDone(http2ConnID3) - XCTAssertEqual(action3.request, .none) - XCTAssertEqual(action3.connection, .none) } func testMaxConcurrentStreamsIsRespected() { @@ -1076,7 +1347,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // shall be queued. for i in 0..<1000 { let requestEL = elg.next() - let mockRequest = MockHTTPRequest(eventLoop: requestEL) + let mockRequest = MockHTTPScheduableRequest(eventLoop: requestEL) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) @@ -1084,10 +1355,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { case 0: XCTAssertEqual(executeAction.connection, .cancelTimeoutTimer(generalPurposeConnection.id)) XCTAssertNoThrow(try connections.activateConnection(generalPurposeConnection.id)) - XCTAssertEqual(executeAction.request, .executeRequest(request, generalPurposeConnection, cancelTimeout: false)) + XCTAssertEqual( + executeAction.request, + .executeRequest(request, generalPurposeConnection, cancelTimeout: false) + ) XCTAssertNoThrow(try connections.execute(mockRequest, on: generalPurposeConnection)) case 1..<100: - XCTAssertEqual(executeAction.request, .executeRequest(request, generalPurposeConnection, cancelTimeout: false)) + XCTAssertEqual( + executeAction.request, + .executeRequest(request, generalPurposeConnection, cancelTimeout: false) + ) XCTAssertEqual(executeAction.connection, .none) XCTAssertNoThrow(try connections.execute(mockRequest, on: generalPurposeConnection)) case 100..<1000: @@ -1105,7 +1382,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.finishExecution(generalPurposeConnection.id)) let finishAction = state.http2ConnectionStreamClosed(generalPurposeConnection.id) XCTAssertEqual(finishAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request else { + guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request + else { return XCTFail("Unexpected request action: \(finishAction.request)") } guard requests.count == 1, let request = requests.first else { @@ -1120,11 +1398,23 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // Next the server allows for more concurrent streams let newMaxStreams = 200 - XCTAssertNoThrow(try connections.newHTTP2ConnectionSettingsReceived(generalPurposeConnection.id, maxConcurrentStreams: newMaxStreams)) - let newMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived(generalPurposeConnection.id, newMaxStreams: newMaxStreams) + XCTAssertNoThrow( + try connections.newHTTP2ConnectionSettingsReceived( + generalPurposeConnection.id, + maxConcurrentStreams: newMaxStreams + ) + ) + let newMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived( + generalPurposeConnection.id, + newMaxStreams: newMaxStreams + ) XCTAssertEqual(newMaxStreamsAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = newMaxStreamsAction.request else { - return XCTFail("Unexpected request action after new max concurrent stream setting: \(newMaxStreamsAction.request)") + guard + case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = newMaxStreamsAction.request + else { + return XCTFail( + "Unexpected request action after new max concurrent stream setting: \(newMaxStreamsAction.request)" + ) } XCTAssertEqual(requests.count, 100, "Expected to execute 100 more requests") for request in requests { @@ -1141,7 +1431,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.finishExecution(generalPurposeConnection.id)) let finishAction = state.http2ConnectionStreamClosed(generalPurposeConnection.id) XCTAssertEqual(finishAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request else { + guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request + else { return XCTFail("Unexpected request action: \(finishAction.request)") } guard requests.count == 1, let request = requests.first else { @@ -1154,8 +1445,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // Next the server allows for fewer concurrent streams let fewerMaxStreams = 50 - XCTAssertNoThrow(try connections.newHTTP2ConnectionSettingsReceived(generalPurposeConnection.id, maxConcurrentStreams: fewerMaxStreams)) - let fewerMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived(generalPurposeConnection.id, newMaxStreams: fewerMaxStreams) + XCTAssertNoThrow( + try connections.newHTTP2ConnectionSettingsReceived( + generalPurposeConnection.id, + maxConcurrentStreams: fewerMaxStreams + ) + ) + let fewerMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived( + generalPurposeConnection.id, + newMaxStreams: fewerMaxStreams + ) XCTAssertEqual(fewerMaxStreamsAction.connection, .none) XCTAssertEqual(fewerMaxStreamsAction.request, .none) @@ -1173,7 +1472,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.finishExecution(generalPurposeConnection.id)) let finishAction = state.http2ConnectionStreamClosed(generalPurposeConnection.id) XCTAssertEqual(finishAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request else { + guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request + else { return XCTFail("Unexpected request action: \(finishAction.request)") } guard requests.count == 1, let request = requests.first else { @@ -1193,7 +1493,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { switch remaining { case 1: timeoutTimerScheduled = true - XCTAssertEqual(finishAction.connection, .scheduleTimeoutTimer(generalPurposeConnection.id, on: generalPurposeConnection.eventLoop)) + XCTAssertEqual( + finishAction.connection, + .scheduleTimeoutTimer(generalPurposeConnection.id, on: generalPurposeConnection.eventLoop) + ) XCTAssertNoThrow(try connections.parkConnection(generalPurposeConnection.id)) case 2...50: XCTAssertEqual(finishAction.connection, .none) @@ -1215,7 +1518,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let connection = connections.randomParkedConnection()! XCTAssertNoThrow(try connections.closeConnection(connection)) - let idleTimeoutAction = state.connectionIdleTimeout(connection.id) + let idleTimeoutAction = state.connectionIdleTimeout(connection.id, on: connection.eventLoop) XCTAssertEqual(idleTimeoutAction.connection, .closeConnection(connection, isShutdown: .no)) XCTAssertEqual(idleTimeoutAction.request, .none) @@ -1224,6 +1527,50 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertEqual(state.http2ConnectionClosed(connection.id), .none) } + + func testFailConnectionRacesAgainstConnectionCreationFailed() { + let elg = EmbeddedEventLoopGroup(loops: 4) + defer { XCTAssertNoThrow(try elg.syncShutdownGracefully()) } + + var state = HTTPConnectionPool.StateMachine( + idGenerator: .init(), + maximumConcurrentHTTP1Connections: 2, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 + ) + + let mockRequest = MockHTTPScheduableRequest(eventLoop: elg.next()) + let request = HTTPConnectionPool.Request(mockRequest) + + let executeAction = state.executeRequest(request) + XCTAssertEqual(.scheduleRequestTimeout(for: request, on: mockRequest.eventLoop), executeAction.request) + + // 1. connection attempt + guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { + return XCTFail("Unexpected connection action: \(executeAction.connection)") + } + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + + // 2. connection fails – first with closed callback + + XCTAssertEqual(state.http2ConnectionClosed(connectionID), .none) + + // 3. connection fails – with make connection callback + + let action = state.failedToCreateNewConnection( + IOError(errnoCode: -1, reason: "Test failure"), + connectionID: connectionID + ) + XCTAssertEqual(action.request, .none) + guard case .scheduleBackoffTimer(connectionID, _, on: let backoffTimerEL) = action.connection else { + XCTFail("Unexpected connection action: \(action.connection)") + return + } + XCTAssertIdentical(connectionEL, backoffTimerEL) + } + } /// Should be used if you have a value of statically unknown type and want to compare its value to another `Equatable` value. @@ -1235,16 +1582,20 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { func XCTAssertEqualTypeAndValue( _ lhs: @autoclosure () throws -> Left, _ rhs: @autoclosure () throws -> Right, - file: StaticString = #file, + file: StaticString = #filePath, line: UInt = #line ) { - XCTAssertNoThrow(try { - let lhs = try lhs() - let rhs = try rhs() - guard let lhsAsRhs = lhs as? Right else { - XCTFail("could not cast \(lhs) of type \(Right.self) to \(Left.self)") - return - } - XCTAssertEqual(lhsAsRhs, rhs) - }(), file: file, line: line) + XCTAssertNoThrow( + try { + let lhs = try lhs() + let rhs = try rhs() + guard let lhsAsRhs = lhs as? Right else { + XCTFail("could not cast \(lhs) of type \(type(of: lhs)) to \(type(of: rhs))", file: file, line: line) + return + } + XCTAssertEqual(lhsAsRhs, rhs, file: file, line: line) + }(), + file: file, + line: line + ) } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests+XCTest.swift deleted file mode 100644 index 93945f63c..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests+XCTest.swift +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+ManagerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_ManagerTests { - static var allTests: [(String, (HTTPConnectionPool_ManagerTests) -> () throws -> Void)] { - return [ - ("testManagerHappyPath", testManagerHappyPath), - ("testShutdownManagerThatHasSeenNoConnections", testShutdownManagerThatHasSeenNoConnections), - ("testExecutingARequestOnAShutdownPoolManager", testExecutingARequestOnAShutdownPoolManager), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift index d84e7f442..724c00b1f 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift @@ -12,12 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient +import Logging import NIOCore import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_ManagerTests: XCTestCase { func testManagerHappyPath() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 4) @@ -49,15 +51,17 @@ class HTTPConnectionPool_ManagerTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -105,15 +109,17 @@ class HTTPConnectionPool_ManagerTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests+XCTest.swift deleted file mode 100644 index 2511ba267..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests+XCTest.swift +++ /dev/null @@ -1,31 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPool+RequestQueueTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPool_RequestQueueTests { - static var allTests: [(String, (HTTPConnectionPool_RequestQueueTests) -> () throws -> Void)] { - return [ - ("testCountAndIsEmptyWorks", testCountAndIsEmptyWorks), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift index f8d6044cd..4f4bbd785 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded @@ -20,6 +19,8 @@ import NIOHTTP1 import NIOSSL import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_RequestQueueTests: XCTestCase { func testCountAndIsEmptyWorks() { var queue = HTTPConnectionPool.RequestQueue() @@ -82,7 +83,7 @@ class HTTPConnectionPool_RequestQueueTests: XCTestCase { } } -private class MockScheduledRequest: HTTPSchedulableRequest { +final private class MockScheduledRequest: HTTPSchedulableRequest { let requiredEventLoop: EventLoop? init(requiredEventLoop: EventLoop?) { diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift index cb67837d7..bd9752d5d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift @@ -12,28 +12,30 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient +import Atomics import Dispatch import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +@testable import AsyncHTTPClient + /// An `EventLoopGroup` of `EmbeddedEventLoop`s. final class EmbeddedEventLoopGroup: EventLoopGroup { private let loops: [EmbeddedEventLoop] - private let index = NIOAtomic.makeAtomic(value: 0) + private let index = ManagedAtomic(0) internal init(loops: Int) { self.loops = (0.. EventLoop { - let index: Int = self.index.add(1) + let index: Int = self.index.loadThenWrappingIncrement(ordering: .relaxed) return self.loops[index % self.loops.count] } internal func makeIterator() -> EventLoopIterator { - return EventLoopIterator(self.loops) + EventLoopIterator(self.loops) } internal func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { @@ -55,7 +57,7 @@ final class EmbeddedEventLoopGroup: EventLoopGroup { extension HTTPConnectionPool.Request: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { - return lhs.id == rhs.id + lhs.id == rhs.id } } @@ -77,15 +79,24 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction: Equatable { switch (lhs, rhs) { case (.createConnection(let lhsConnID, on: let lhsEL), .createConnection(let rhsConnID, on: let rhsEL)): return lhsConnID == rhsConnID && lhsEL === rhsEL - case (.scheduleBackoffTimer(let lhsConnID, let lhsBackoff, on: let lhsEL), .scheduleBackoffTimer(let rhsConnID, let rhsBackoff, on: let rhsEL)): + case ( + .scheduleBackoffTimer(let lhsConnID, let lhsBackoff, on: let lhsEL), + .scheduleBackoffTimer(let rhsConnID, let rhsBackoff, on: let rhsEL) + ): return lhsConnID == rhsConnID && lhsBackoff == rhsBackoff && lhsEL === rhsEL case (.scheduleTimeoutTimer(let lhsConnID, on: let lhsEL), .scheduleTimeoutTimer(let rhsConnID, on: let rhsEL)): return lhsConnID == rhsConnID && lhsEL === rhsEL case (.cancelTimeoutTimer(let lhsConnID), .cancelTimeoutTimer(let rhsConnID)): return lhsConnID == rhsConnID - case (.closeConnection(let lhsConn, isShutdown: let lhsShut), .closeConnection(let rhsConn, isShutdown: let rhsShut)): + case ( + .closeConnection(let lhsConn, isShutdown: let lhsShut), + .closeConnection(let rhsConn, isShutdown: let rhsShut) + ): return lhsConn == rhsConn && lhsShut == rhsShut - case (.cleanupConnections(let lhsContext, isShutdown: let lhsShut), .cleanupConnections(let rhsContext, isShutdown: let rhsShut)): + case ( + .cleanupConnections(let lhsContext, isShutdown: let lhsShut), + .cleanupConnections(let rhsContext, isShutdown: let rhsShut) + ): return lhsContext == rhsContext && lhsShut == rhsShut case ( .migration( @@ -99,12 +110,13 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction: Equatable { let rhsScheduleTimeout ) ): - return lhsCreateConnections.elementsEqual(rhsCreateConnections, by: { - $0.0 == $1.0 && $0.1 === $1.1 - }) && - lhsCloseConnections == rhsCloseConnections && - lhsScheduleTimeout?.0 == rhsScheduleTimeout?.0 && - lhsScheduleTimeout?.1 === rhsScheduleTimeout?.1 + return lhsCreateConnections.elementsEqual( + rhsCreateConnections, + by: { + $0.0 == $1.0 && $0.1 === $1.1 + } + ) && lhsCloseConnections == rhsCloseConnections && lhsScheduleTimeout?.0 == rhsScheduleTimeout?.0 + && lhsScheduleTimeout?.1 === rhsScheduleTimeout?.1 case (.none, .none): return true default: @@ -116,18 +128,28 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction: Equatable { extension HTTPConnectionPool.StateMachine.RequestAction: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { - case (.executeRequest(let lhsReq, let lhsConn, let lhsReqID), .executeRequest(let rhsReq, let rhsConn, let rhsReqID)): + case ( + .executeRequest(let lhsReq, let lhsConn, let lhsReqID), + .executeRequest(let rhsReq, let rhsConn, let rhsReqID) + ): return lhsReq == rhsReq && lhsConn == rhsConn && lhsReqID == rhsReqID - case (.executeRequestsAndCancelTimeouts(let lhsReqs, let lhsConn), .executeRequestsAndCancelTimeouts(let rhsReqs, let rhsConn)): + case ( + .executeRequestsAndCancelTimeouts(let lhsReqs, let lhsConn), + .executeRequestsAndCancelTimeouts(let rhsReqs, let rhsConn) + ): return lhsReqs.elementsEqual(rhsReqs, by: { $0 == $1 }) && lhsConn == rhsConn - case (.failRequest(let lhsReq, _, cancelTimeout: let lhsReqID), .failRequest(let rhsReq, _, cancelTimeout: let rhsReqID)): + case ( + .failRequest(let lhsReq, _, cancelTimeout: let lhsReqID), + .failRequest(let rhsReq, _, cancelTimeout: let rhsReqID) + ): return lhsReq == rhsReq && lhsReqID == rhsReqID case (.failRequestsAndCancelTimeouts(let lhsReqs, _), .failRequestsAndCancelTimeouts(let rhsReqs, _)): return lhsReqs.elementsEqual(rhsReqs, by: { $0 == $1 }) - case (.scheduleRequestTimeout(for: let lhsReq, on: let lhsEL), .scheduleRequestTimeout(for: let rhsReq, on: let rhsEL)): + case ( + .scheduleRequestTimeout(for: let lhsReq, on: let lhsEL), + .scheduleRequestTimeout(for: let rhsReq, on: let rhsEL) + ): return lhsReq == rhsReq && lhsEL === rhsEL - case (.cancelRequestTimeout(let lhsReqID), .cancelRequestTimeout(let rhsReqID)): - return lhsReqID == rhsReqID case (.none, .none): return true default: @@ -147,7 +169,10 @@ extension HTTPConnectionPool.HTTP2StateMachine.EstablishedConnectionAction: Equa switch (lhs, rhs) { case (.scheduleTimeoutTimer(let lhsConnID, on: let lhsEL), .scheduleTimeoutTimer(let rhsConnID, on: let rhsEL)): return lhsConnID == rhsConnID && lhsEL === rhsEL - case (.closeConnection(let lhsConn, isShutdown: let lhsShut), .closeConnection(let rhsConn, isShutdown: let rhsShut)): + case ( + .closeConnection(let lhsConn, isShutdown: let lhsShut), + .closeConnection(let rhsConn, isShutdown: let rhsShut) + ): return lhsConn == rhsConn && lhsShut == rhsShut case (.none, .none): return true diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests+XCTest.swift deleted file mode 100644 index acdc0ab26..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests+XCTest.swift +++ /dev/null @@ -1,39 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPConnectionPoolTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPConnectionPoolTests { - static var allTests: [(String, (HTTPConnectionPoolTests) -> () throws -> Void)] { - return [ - ("testOnlyOneConnectionIsUsedForSubSequentRequests", testOnlyOneConnectionIsUsedForSubSequentRequests), - ("testConnectionsForEventLoopRequirementsAreClosed", testConnectionsForEventLoopRequirementsAreClosed), - ("testConnectionPoolGrowsToMaxConcurrentConnections", testConnectionPoolGrowsToMaxConcurrentConnections), - ("testConnectionCreationIsRetriedUntilRequestIsFailed", testConnectionCreationIsRetriedUntilRequestIsFailed), - ("testConnectionCreationIsRetriedUntilPoolIsShutdown", testConnectionCreationIsRetriedUntilPoolIsShutdown), - ("testConnectionCreationIsRetriedUntilRequestIsCancelled", testConnectionCreationIsRetriedUntilRequestIsCancelled), - ("testConnectionShutdownIsCalledOnActiveConnections", testConnectionShutdownIsCalledOnActiveConnections), - ("testConnectionPoolStressResistanceHTTP1", testConnectionPoolStressResistanceHTTP1), - ("testBackoffBehavesSensibly", testBackoffBehavesSensibly), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift index 60e5077ee..a40703456 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPoolTests: XCTestCase { func testOnlyOneConnectionIsUsedForSubSequentRequests() { let httpBin = HTTPBin() @@ -53,15 +54,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -82,7 +85,6 @@ class HTTPConnectionPoolTests: XCTestCase { let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)") let poolDelegate = TestDelegate(eventLoop: eventLoop) - let pool = HTTPConnectionPool( eventLoopGroup: eventLoopGroup, sslContextCache: .init(), @@ -93,6 +95,74 @@ class HTTPConnectionPoolTests: XCTestCase { idGenerator: .init(), backgroundActivityLogger: .init(label: "test") ) + defer { + pool.shutdown() + XCTAssertNoThrow(try poolDelegate.future.wait()) + XCTAssertNoThrow(try eventLoop.scheduleTask(in: .milliseconds(100)) {}.futureResult.wait()) + XCTAssertEqual(httpBin.activeConnections, 0) + // Since we would migrate from h2 -> h1, which creates a general purpose connection + // for every connection in .starting state, after the first request which will + // be serviced by an overflow connection, the rest of requests will use the general + // purpose connection since they are all on the same event loop. + // Hence we will only create 1 overflow connection and 1 general purpose connection. + XCTAssertEqual(httpBin.createdConnections, 2) + } + + XCTAssertEqual(httpBin.createdConnections, 0) + + for _ in 0..<10 { + var maybeRequest: HTTPClient.Request? + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .init( + .testOnly_exact(channelOn: eventLoopGroup.next(), delegateOn: eventLoopGroup.next()) + ), + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) + + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } + + pool.executeRequest(requestBag) + XCTAssertNoThrow(try requestBag.task.futureResult.wait()) + + // Flakiness Alert: We check <= and >= instead of == + // While migration from h2 -> h1, one general purpose and one over flow connection + // will be created, there's no guarantee as to whether the request is executed + // after both are created. + XCTAssertGreaterThanOrEqual(httpBin.createdConnections, 1) + XCTAssertLessThanOrEqual(httpBin.createdConnections, 2) + } + } + + func testConnectionsForEventLoopRequirementsAreClosedH1Only() { + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)") + let poolDelegate = TestDelegate(eventLoop: eventLoop) + var configuration = HTTPClient.Configuration() + configuration.httpVersion = .http1Only + let pool = HTTPConnectionPool( + eventLoopGroup: eventLoopGroup, + sslContextCache: .init(), + tlsConfiguration: .none, + clientConfiguration: configuration, + key: .init(request), + delegate: poolDelegate, + idGenerator: .init(), + backgroundActivityLogger: .init(label: "test") + ) defer { pool.shutdown() XCTAssertNoThrow(try poolDelegate.future.wait()) @@ -107,15 +177,19 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .init(.testOnly_exact(channelOn: eventLoopGroup.next(), delegateOn: eventLoopGroup.next())), - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .init( + .testOnly_exact(channelOn: eventLoopGroup.next(), delegateOn: eventLoopGroup.next()) + ), + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -162,15 +236,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -216,15 +292,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -264,15 +342,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -320,21 +400,23 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } pool.executeRequest(requestBag) XCTAssertNoThrow(try eventLoop.scheduleTask(in: .seconds(1)) {}.futureResult.wait()) - requestBag.cancel() + requestBag.fail(HTTPClientError.cancelled) XCTAssertThrowsError(try requestBag.task.futureResult.wait()) { XCTAssertEqual($0 as? HTTPClientError, .cancelled) @@ -366,15 +448,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/wait")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -419,22 +503,24 @@ class HTTPConnectionPoolTests: XCTestCase { let dispatchGroup = DispatchGroup() for workerID in 0..? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: url)) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } pool.executeRequest(requestBag) @@ -458,7 +544,10 @@ class HTTPConnectionPoolTests: XCTestCase { var backoff = HTTPConnectionPool.calculateBackoff(failedAttempt: 1) // The value should be 100ms±3ms - XCTAssertLessThanOrEqual((backoff - .milliseconds(100)).nanoseconds.magnitude, TimeAmount.milliseconds(3).nanoseconds.magnitude) + XCTAssertLessThanOrEqual( + (backoff - .milliseconds(100)).nanoseconds.magnitude, + TimeAmount.milliseconds(3).nanoseconds.magnitude + ) // Should always increase // We stop when we get within the jitter of 60s, which is 1.8s @@ -474,7 +563,8 @@ class HTTPConnectionPoolTests: XCTestCase { // Ok, now we should be able to do a hundred increments, and always hit 60s, plus or minus 1.8s of jitter. for offset in 0..<100 { XCTAssertLessThanOrEqual( - (HTTPConnectionPool.calculateBackoff(failedAttempt: attempt + offset) - .seconds(60)).nanoseconds.magnitude, + (HTTPConnectionPool.calculateBackoff(failedAttempt: attempt + offset) - .seconds(60)).nanoseconds + .magnitude, TimeAmount.milliseconds(1800).nanoseconds.magnitude ) } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift deleted file mode 100644 index b54865fd8..000000000 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift +++ /dev/null @@ -1,66 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// HTTPRequestStateMachineTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension HTTPRequestStateMachineTests { - static var allTests: [(String, (HTTPRequestStateMachineTests) -> () throws -> Void)] { - return [ - ("testSimpleGETRequest", testSimpleGETRequest), - ("testPOSTRequestWithWriterBackpressure", testPOSTRequestWithWriterBackpressure), - ("testPOSTContentLengthIsTooLong", testPOSTContentLengthIsTooLong), - ("testPOSTContentLengthIsTooShort", testPOSTContentLengthIsTooShort), - ("testRequestBodyStreamIsCancelledIfServerRespondsWith301", testRequestBodyStreamIsCancelledIfServerRespondsWith301), - ("testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure", testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure), - ("testRequestBodyStreamIsContinuedIfServerRespondsWith200", testRequestBodyStreamIsContinuedIfServerRespondsWith200), - ("testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200", testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200), - ("testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200", testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200), - ("testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerSendHeadWithStatus200", testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerSendHeadWithStatus200), - ("testRequestIsNotSendUntilChannelIsWritable", testRequestIsNotSendUntilChannelIsWritable), - ("testConnectionBecomesInactiveWhileWaitingForWritable", testConnectionBecomesInactiveWhileWaitingForWritable), - ("testResponseReadingWithBackpressure", testResponseReadingWithBackpressure), - ("testChannelReadCompleteTriggersButNoBodyDataWasReceivedSoFar", testChannelReadCompleteTriggersButNoBodyDataWasReceivedSoFar), - ("testResponseReadingWithBackpressureEndOfResponseAllowsReadEventsToTriggerDirectly", testResponseReadingWithBackpressureEndOfResponseAllowsReadEventsToTriggerDirectly), - ("testCancellingARequestInStateInitializedKeepsTheConnectionAlive", testCancellingARequestInStateInitializedKeepsTheConnectionAlive), - ("testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive", testCancellingARequestBeforeBeingSendKeepsTheConnectionAlive), - ("testConnectionBecomesWritableBeforeFirstRequest", testConnectionBecomesWritableBeforeFirstRequest), - ("testCancellingARequestThatIsSent", testCancellingARequestThatIsSent), - ("testRemoteSuddenlyClosesTheConnection", testRemoteSuddenlyClosesTheConnection), - ("testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnored", testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnored), - ("testResponseWithStatus1XXAreIgnored", testResponseWithStatus1XXAreIgnored), - ("testReadTimeoutThatFiresToLateIsIgnored", testReadTimeoutThatFiresToLateIsIgnored), - ("testCancellationThatIsInvokedToLateIsIgnored", testCancellationThatIsInvokedToLateIsIgnored), - ("testErrorWhileRunningARequestClosesTheStream", testErrorWhileRunningARequestClosesTheStream), - ("testCanReadHTTP1_0ResponseWithoutBody", testCanReadHTTP1_0ResponseWithoutBody), - ("testCanReadHTTP1_0ResponseWithBody", testCanReadHTTP1_0ResponseWithBody), - ("testFailHTTP1_0RequestThatIsStillUploading", testFailHTTP1_0RequestThatIsStillUploading), - ("testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown", testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown), - ("testNIOSSLErrorUncleanShutdownShouldBeTreatedAsRemoteConnectionCloseWhileInWaitingForHeadState", testNIOSSLErrorUncleanShutdownShouldBeTreatedAsRemoteConnectionCloseWhileInWaitingForHeadState), - ("testArbitraryErrorShouldBeTreatedAsARequestFailureWhileInWaitingForHeadState", testArbitraryErrorShouldBeTreatedAsARequestFailureWhileInWaitingForHeadState), - ("testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt", testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt), - ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand), - ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead), - ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand), - ("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index ab55345c9..8fe879745 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -12,21 +12,29 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore +import NIOEmbedded import NIOHTTP1 import NIOSSL import XCTest +@testable import AsyncHTTPClient + class HTTPRequestStateMachineTests: XCTestCase { func testSimpleGETRequest() { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -35,32 +43,47 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTRequestWithWriterBackpressure() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "4")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) let part3 = IOData.byteBuffer(ByteBuffer(bytes: [3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) // oh the channel reports... we should slow down producing... XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) // but we issued a .produceMoreRequestBodyData before... Thus, we must accept more produced // data - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) // however when we have put the data on the channel, we should not issue further // .produceMoreRequestBodyData events // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) - XCTAssertEqual(state.requestStreamPartReceived(part3), .sendBodyPart(part3)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) + XCTAssertEqual(state.requestStreamPartReceived(part3, promise: nil), .sendBodyPart(part3, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -69,14 +92,25 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTContentLengthIsTooLong() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "4")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) - state.requestStreamPartReceived(part1).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + state.requestStreamPartReceived(part1, promise: nil).assertFailRequest( + HTTPClientError.bodyLengthMismatch, + .close(nil) + ) // if another error happens the new one is ignored XCTAssertEqual(state.errorHappened(HTTPClientError.remoteConnectionClosed), .wait) @@ -84,140 +118,257 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTContentLengthIsTooShort() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "8")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "8")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(8)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) - state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) } func testRequestBodyStreamIsCancelledIfServerRespondsWith301() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part), .sendBodyPart(part)) + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .sendBodyPart(part, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .movedPermanently) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: true)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: true) + ) XCTAssertEqual(state.writabilityChanged(writable: false), .wait) XCTAssertEqual(state.writabilityChanged(writable: true), .wait) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init())) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) + + XCTAssertEqual( + state.requestStreamFinished(promise: nil), + .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) + } - XCTAssertEqual(state.requestStreamFinished(), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") + func testStreamPartReceived_whenCancelled() { + var state = HTTPRequestStateMachine(isChannelWritable: false) + let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) + + XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none)) + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.cancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) } func testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part), .sendBodyPart(part)) + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .sendBodyPart(part, nil)) XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .movedPermanently) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.writabilityChanged(writable: true), .wait) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init())) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") - - XCTAssertEqual(state.requestStreamFinished(), .wait, - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) + + XCTAssertEqual( + state.requestStreamFinished(promise: nil), + .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) } func testRequestBodyStreamIsContinuedIfServerRespondsWith200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseBodyParts(.init())) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) - XCTAssertEqual(state.requestStreamFinished(), .succeedRequest(.sendRequestEnd, .init())) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .succeedRequest(.sendRequestEnd(nil), .init())) + + XCTAssertEqual( + state.requestStreamPartReceived(part2, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil) + ) } func testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) } func testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseBodyParts(.init())) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) + state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } func testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerSendHeadWithStatus200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) + state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) XCTAssertEqual(state.channelRead(.end(nil)), .wait) } @@ -227,10 +378,13 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .wait) XCTAssertEqual(state.read(), .read) - XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -249,10 +403,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -276,10 +440,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -303,10 +477,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -321,7 +505,11 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelRead(.body(part2)), .wait) XCTAssertEqual(state.read(), .read, "Calling `read` while we wait for a channelReadComplete doesn't crash") - XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait, "Calling `demandMoreResponseBodyParts` while we wait for a channelReadComplete doesn't crash") + XCTAssertEqual( + state.demandMoreResponseBodyParts(), + .wait, + "Calling `demandMoreResponseBodyParts` while we wait for a channelReadComplete doesn't crash" + ) XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts(.init([part2]))) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) @@ -350,11 +538,17 @@ class HTTPRequestStateMachineTests: XCTestCase { // --- sending request let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) // --- receiving response let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "4"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -365,30 +559,54 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close(nil)) } func testRemoteSuddenlyClosesTheConnection() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: .init([("content-length", "4")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: .init([("content-length", "4")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) - state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close) - XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3))), .wait) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close(nil)) + XCTAssertEqual( + state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3)), promise: nil), + .failSendBodyPart(HTTPClientError.cancelled, nil) + ) } func testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnored() { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) XCTAssertEqual(state.channelRead(.body(part0)), .wait) - state.idleReadTimeoutTriggered().assertFailRequest(HTTPClientError.readTimeout, .close) + state.idleReadTimeoutTriggered().assertFailRequest(HTTPClientError.readTimeout, .close(nil)) XCTAssertEqual(state.channelRead(.body(ByteBuffer(bytes: 4...7))), .wait) XCTAssertEqual(state.channelRead(.body(ByteBuffer(bytes: 8...11))), .wait) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) @@ -399,13 +617,19 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let continueHead = HTTPResponseHead(version: .http1_1, status: .continue) XCTAssertEqual(state.channelRead(.head(continueHead)), .wait) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -415,10 +639,16 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) XCTAssertEqual(state.idleReadTimeoutTriggered(), .wait, "A read timeout that fires to late must be ignored") } @@ -427,10 +657,16 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") } @@ -439,9 +675,15 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - - state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest(HTTPParserError.invalidChunkSize, .close) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest( + HTTPParserError.invalidChunkSize, + .close(nil) + ) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") } @@ -449,10 +691,16 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .internalServerError) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -465,11 +713,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .internalServerError) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -483,19 +737,28 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .stream) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part1: ByteBuffer = .init(string: "foo") - XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(part1)), .sendBodyPart(.byteBuffer(part1))) + XCTAssertEqual( + state.requestStreamPartReceived(.byteBuffer(part1), promise: nil), + .sendBodyPart(.byteBuffer(part1), nil) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .ok) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) - state.channelRead(.end(nil)).assertFailRequest(HTTPClientError.remoteConnectionClosed, .close) + state.channelRead(.end(nil)).assertFailRequest(HTTPClientError.remoteConnectionClosed, .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -503,14 +766,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) - state.errorHappened(NIOSSLError.uncleanShutdown).assertFailRequest(NIOSSLError.uncleanShutdown, .close) + state.errorHappened(NIOSSLError.uncleanShutdown).assertFailRequest(NIOSSLError.uncleanShutdown, .close(nil)) XCTAssertEqual(state.channelRead(.end(nil)), .wait) XCTAssertEqual(state.channelInactive(), .wait) } @@ -519,7 +788,10 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait) state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none) @@ -530,9 +802,12 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) - state.errorHappened(ArbitraryError()).assertFailRequest(ArbitraryError(), .close) + state.errorHappened(ArbitraryError()).assertFailRequest(ArbitraryError(), .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -540,17 +815,26 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "30"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelRead(.body(body)), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body])) XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait) - state.errorHappened(HTTPParserError.invalidEOFState).assertFailRequest(HTTPParserError.invalidEOFState, .close) + state.errorHappened(HTTPParserError.invalidEOFState).assertFailRequest( + HTTPParserError.invalidEOFState, + .close(nil) + ) XCTAssertEqual(state.channelInactive(), .wait) } @@ -558,11 +842,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -579,11 +869,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -600,11 +896,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -620,11 +922,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -656,24 +964,36 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.sendRequestHead(let lhsHead, let lhsStartBody), .sendRequestHead(let rhsHead, let rhsStartBody)): return lhsHead == rhsHead && lhsStartBody == rhsStartBody - case (.sendBodyPart(let lhsData), .sendBodyPart(let rhsData)): - return lhsData == rhsData + case ( + .notifyRequestHeadSendSuccessfully(let lhsResumeRequestBodyStream, let lhsStartIdleTimer), + .notifyRequestHeadSendSuccessfully(let rhsResumeRequestBodyStream, let rhsStartIdleTimer) + ): + return lhsResumeRequestBodyStream == rhsResumeRequestBodyStream && lhsStartIdleTimer == rhsStartIdleTimer - case (.sendRequestEnd, .sendRequestEnd): - return true + case (.sendBodyPart(let lhsData, let lhsPromise), .sendBodyPart(let rhsData, let rhsPromise)): + return lhsData == rhsData && lhsPromise?.futureResult == rhsPromise?.futureResult + + case (.sendRequestEnd(let lhsPromise), .sendRequestEnd(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult case (.pauseRequestBodyStream, .pauseRequestBodyStream): return true case (.resumeRequestBodyStream, .resumeRequestBodyStream): return true - case (.forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream)): + case ( + .forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), + .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream) + ): return lhsHead == rhsHead && lhsPauseRequestBodyStream == rhsPauseRequestBodyStream case (.forwardResponseBodyParts(let lhsData), .forwardResponseBodyParts(let rhsData)): return lhsData == rhsData - case (.succeedRequest(let lhsFinalAction, let lhsFinalBuffer), .succeedRequest(let rhsFinalAction, let rhsFinalBuffer)): + case ( + .succeedRequest(let lhsFinalAction, let lhsFinalBuffer), + .succeedRequest(let rhsFinalAction, let rhsFinalBuffer) + ): return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): @@ -685,6 +1005,57 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.wait, .wait): return true + case ( + .failSendBodyPart(let lhsError as HTTPClientError, let lhsPromise), + .failSendBodyPart(let rhsError as HTTPClientError, let rhsPromise) + ): + return lhsError == rhsError && lhsPromise?.futureResult == rhsPromise?.futureResult + + case ( + .failSendStreamFinished(let lhsError as HTTPClientError, let lhsPromise), + .failSendStreamFinished(let rhsError as HTTPClientError, let rhsPromise) + ): + return lhsError == rhsError && lhsPromise?.futureResult == rhsPromise?.futureResult + + default: + return false + } + } +} + +extension HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction: Equatable { + public static func == ( + lhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, + rhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction + ) -> Bool { + switch (lhs, rhs) { + case (.close, close): + return true + + case (.sendRequestEnd(let lhsPromise), .sendRequestEnd(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + + case (.none, .none): + return true + + default: + return false + } + } +} + +extension HTTPRequestStateMachine.Action.FinalFailedRequestAction: Equatable { + public static func == ( + lhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction, + rhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction + ) -> Bool { + switch (lhs, rhs) { + case (.close(let lhsPromise), close(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + + case (.none, .none): + return true + default: return false } @@ -694,12 +1065,16 @@ extension HTTPRequestStateMachine.Action: Equatable { extension HTTPRequestStateMachine.Action { fileprivate func assertFailRequest( _ expectedError: Error, - _ expectedFinalStreamAction: HTTPRequestStateMachine.Action.FinalStreamAction, - file: StaticString = #file, + _ expectedFinalStreamAction: HTTPRequestStateMachine.Action.FinalFailedRequestAction, + file: StaticString = #filePath, line: UInt = #line ) where Error: Swift.Error & Equatable { guard case .failRequest(let actualError, let actualFinalStreamAction) = self else { - return XCTFail("expected .failRequest(\(expectedError), \(expectedFinalStreamAction)) but got \(self)", file: file, line: line) + return XCTFail( + "expected .failRequest(\(expectedError), \(expectedFinalStreamAction)) but got \(self)", + file: file, + line: line + ) } if let actualError = actualError as? Error { XCTAssertEqual(actualError, expectedError, file: file, line: line) diff --git a/Tests/AsyncHTTPClientTests/IdleTimeoutNoReuseTests.swift b/Tests/AsyncHTTPClientTests/IdleTimeoutNoReuseTests.swift new file mode 100644 index 000000000..e9a0d46dc --- /dev/null +++ b/Tests/AsyncHTTPClientTests/IdleTimeoutNoReuseTests.swift @@ -0,0 +1,41 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class TestIdleTimeoutNoReuse: XCTestCaseHTTPClientTestsBaseClass { + func testIdleTimeoutNoReuse() throws { + var req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .GET) + XCTAssertNoThrow(try self.defaultClient.execute(request: req, deadline: .now() + .seconds(2)).wait()) + req.headers.add(name: "X-internal-delay", value: "2500") + try self.defaultClient.eventLoopGroup.next().scheduleTask(in: .milliseconds(250)) {}.futureResult.wait() + XCTAssertNoThrow(try self.defaultClient.execute(request: req).timeout(after: .seconds(10)).wait()) + } +} diff --git a/Tests/AsyncHTTPClientTests/LRUCacheTests+XCTest.swift b/Tests/AsyncHTTPClientTests/LRUCacheTests+XCTest.swift deleted file mode 100644 index a0231bf0d..000000000 --- a/Tests/AsyncHTTPClientTests/LRUCacheTests+XCTest.swift +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// LRUCacheTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension LRUCacheTests { - static var allTests: [(String, (LRUCacheTests) -> () throws -> Void)] { - return [ - ("testBasicsWork", testBasicsWork), - ("testCachesTheRightThings", testCachesTheRightThings), - ("testAppendingTheSameDoesNotEvictButUpdates", testAppendingTheSameDoesNotEvictButUpdates), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/LRUCacheTests.swift b/Tests/AsyncHTTPClientTests/LRUCacheTests.swift index 6392bcebe..6173c34eb 100644 --- a/Tests/AsyncHTTPClientTests/LRUCacheTests.swift +++ b/Tests/AsyncHTTPClientTests/LRUCacheTests.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import XCTest +@testable import AsyncHTTPClient + class LRUCacheTests: XCTestCase { func testBasicsWork() { var cache = LRUCache(capacity: 1) diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift b/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift index eedc499ad..a6b48fb9a 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOHTTP1 import NIOSSL +@testable import AsyncHTTPClient + /// A mock connection pool (not creating any actual connections) that is used to validate /// connection actions returned by the `HTTPConnectionPool.StateMachine`. struct MockConnectionPool { @@ -541,17 +542,23 @@ extension MockConnectionPool { ) throws -> (Self, HTTPConnectionPool.StateMachine) { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: maxNumberOfConnections + maximumConcurrentHTTP1Connections: maxNumberOfConnections, + retryConnectionEstablishment: true, + preferHTTP1: true, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 ) var connections = MockConnectionPool() var queuer = MockRequestQueuer() for _ in 0.. (Self, HTTPConnectionPool.StateMachine) { var state = HTTPConnectionPool.StateMachine( idGenerator: .init(), - maximumConcurrentHTTP1Connections: 8 + maximumConcurrentHTTP1Connections: 8, + retryConnectionEstablishment: true, + preferHTTP1: false, + maximumConnectionUses: nil, + preWarmedHTTP1ConnectionCount: 0 ) var connections = MockConnectionPool() var queuer = MockRequestQueuer() // 1. Schedule one request to create a connection - let mockRequest = MockHTTPRequest(eventLoop: eventLoop ?? elg.next()) + let mockRequest = MockHTTPScheduableRequest(eventLoop: eventLoop ?? elg.next()) let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) - guard case .scheduleRequestTimeout(request, on: let waitEL) = executeAction.request, mockRequest.eventLoop === waitEL else { + guard case .scheduleRequestTimeout(request, on: let waitEL) = executeAction.request, + mockRequest.eventLoop === waitEL + else { throw SetupError.expectedRequestToBeAddedToQueue } @@ -628,17 +641,16 @@ extension MockConnectionPool { // 2. the connection becomes available - let newConnection = try connections.succeedConnectionCreationHTTP2(connectionID, maxConcurrentStreams: maxConcurrentStreams) + let newConnection = try connections.succeedConnectionCreationHTTP2( + connectionID, + maxConcurrentStreams: maxConcurrentStreams + ) let action = state.newHTTP2ConnectionCreated(newConnection, maxConcurrentStreams: maxConcurrentStreams) guard case .executeRequestsAndCancelTimeouts([request], newConnection) = action.request else { throw SetupError.expectedPreviouslyQueuedRequestToBeRunNow } - guard case .migration(createConnections: let create, closeConnections: [], scheduleTimeout: nil) = action.connection, create.isEmpty else { - throw SetupError.expectedNoConnectionAction - } - guard try queuer.get(request.id, request: request.__testOnly_wrapped_request()) === mockRequest else { throw SetupError.expectedPreviouslyQueuedRequestToBeRunNow } @@ -664,7 +676,7 @@ extension MockConnectionPool { /// A request that can be used when testing the `HTTPConnectionPool.StateMachine` /// with the `MockConnectionPool`. -class MockHTTPRequest: HTTPSchedulableRequest { +final class MockHTTPScheduableRequest: HTTPSchedulableRequest { let logger: Logger let connectionDeadline: NIODeadline let requestOptions: RequestOptions @@ -672,10 +684,12 @@ class MockHTTPRequest: HTTPSchedulableRequest { let preferredEventLoop: EventLoop let requiredEventLoop: EventLoop? - init(eventLoop: EventLoop, - logger: Logger = Logger(label: "mock"), - connectionTimeout: TimeAmount = .seconds(60), - requiresEventLoopForChannel: Bool = false) { + init( + eventLoop: EventLoop, + logger: Logger = Logger(label: "mock"), + connectionTimeout: TimeAmount = .seconds(60), + requiresEventLoopForChannel: Bool = false + ) { self.logger = logger self.connectionDeadline = .now() + connectionTimeout @@ -690,7 +704,7 @@ class MockHTTPRequest: HTTPSchedulableRequest { } var eventLoop: EventLoop { - return self.preferredEventLoop + self.preferredEventLoop } // MARK: HTTPSchedulableRequest diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift new file mode 100644 index 000000000..67f18cbb8 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift @@ -0,0 +1,175 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOHTTP1 +import XCTest + +@testable import AsyncHTTPClient + +final class MockHTTPExecutableRequest: HTTPExecutableRequest { + enum Event: Sendable { + /// ``Event`` without associated values + enum Kind: Hashable { + case willExecuteRequest + case requestHeadSent + case resumeRequestBodyStream + case pauseRequestBodyStream + case receiveResponseHead + case receiveResponseBodyParts + case succeedRequest + case fail + } + + case willExecuteRequest(HTTPRequestExecutor) + case requestHeadSent + case resumeRequestBodyStream + case pauseRequestBodyStream + case receiveResponseHead(HTTPResponseHead) + case receiveResponseBodyParts(CircularBuffer) + case succeedRequest(CircularBuffer?) + case fail(Error) + + var kind: Kind { + switch self { + case .willExecuteRequest: return .willExecuteRequest + case .requestHeadSent: return .requestHeadSent + case .resumeRequestBodyStream: return .resumeRequestBodyStream + case .pauseRequestBodyStream: return .pauseRequestBodyStream + case .receiveResponseHead: return .receiveResponseHead + case .receiveResponseBodyParts: return .receiveResponseBodyParts + case .succeedRequest: return .succeedRequest + case .fail: return .fail + } + } + } + + let logger: Logging.Logger = Logger(label: "request") + let requestHead: NIOHTTP1.HTTPRequestHead + let requestFramingMetadata: RequestFramingMetadata + let requestOptions: RequestOptions = .forTests() + + /// if true and ``HTTPExecutableRequest`` method is called without setting a corresponding callback on `self` e.g. + /// If ``HTTPExecutableRequest\.willExecuteRequest(_:)`` is called but ``willExecuteRequestCallback`` is not set, + /// ``XCTestFail(_:)`` will be called to fail the current test. + let raiseErrorIfUnimplementedMethodIsCalled: Bool + private let file: StaticString + private let line: UInt + + let willExecuteRequestCallback: (@Sendable (HTTPRequestExecutor) -> Void)? = nil + let requestHeadSentCallback: (@Sendable () -> Void)? = nil + let resumeRequestBodyStreamCallback: (@Sendable () -> Void)? = nil + let pauseRequestBodyStreamCallback: (@Sendable () -> Void)? = nil + let receiveResponseHeadCallback: (@Sendable (HTTPResponseHead) -> Void)? = nil + let receiveResponseBodyPartsCallback: (@Sendable (CircularBuffer) -> Void)? = nil + let succeedRequestCallback: (@Sendable (CircularBuffer?) -> Void)? = nil + let failCallback: (@Sendable (Error) -> Void)? = nil + + /// captures all ``HTTPExecutableRequest`` method calls in the order of occurrence, including arguments. + /// If you are not interested in the arguments you can use `events.map(\.kind)` to get all events without arguments. + private let _events = NIOLockedValueBox<[Event]>([]) + private(set) var events: [Event] { + get { + self._events.withLockedValue { $0 } + } + set { + self._events.withLockedValue { $0 = newValue } + } + } + + init( + head: NIOHTTP1.HTTPRequestHead = .init(version: .http1_1, method: .GET, uri: "http://localhost/"), + framingMetadata: RequestFramingMetadata = .init(connectionClose: false, body: .fixedSize(0)), + raiseErrorIfUnimplementedMethodIsCalled: Bool = true, + file: StaticString = #file, + line: UInt = #line + ) { + self.requestHead = head + self.requestFramingMetadata = framingMetadata + self.raiseErrorIfUnimplementedMethodIsCalled = raiseErrorIfUnimplementedMethodIsCalled + self.file = file + self.line = line + } + + private func calledUnimplementedMethod(_ name: String) { + guard self.raiseErrorIfUnimplementedMethodIsCalled else { return } + XCTFail("\(name) invoked but it is not implemented", file: self.file, line: self.line) + } + + func willExecuteRequest(_ executor: HTTPRequestExecutor) { + self.events.append(.willExecuteRequest(executor)) + guard let willExecuteRequestCallback = willExecuteRequestCallback else { + return self.calledUnimplementedMethod(#function) + } + willExecuteRequestCallback(executor) + } + + func requestHeadSent() { + self.events.append(.requestHeadSent) + guard let requestHeadSentCallback = requestHeadSentCallback else { + return self.calledUnimplementedMethod(#function) + } + requestHeadSentCallback() + } + + func resumeRequestBodyStream() { + self.events.append(.resumeRequestBodyStream) + guard let resumeRequestBodyStreamCallback = resumeRequestBodyStreamCallback else { + return self.calledUnimplementedMethod(#function) + } + resumeRequestBodyStreamCallback() + } + + func pauseRequestBodyStream() { + self.events.append(.pauseRequestBodyStream) + guard let pauseRequestBodyStreamCallback = pauseRequestBodyStreamCallback else { + return self.calledUnimplementedMethod(#function) + } + pauseRequestBodyStreamCallback() + } + + func receiveResponseHead(_ head: HTTPResponseHead) { + self.events.append(.receiveResponseHead(head)) + guard let receiveResponseHeadCallback = receiveResponseHeadCallback else { + return self.calledUnimplementedMethod(#function) + } + receiveResponseHeadCallback(head) + } + + func receiveResponseBodyParts(_ buffer: CircularBuffer) { + self.events.append(.receiveResponseBodyParts(buffer)) + guard let receiveResponseBodyPartsCallback = receiveResponseBodyPartsCallback else { + return self.calledUnimplementedMethod(#function) + } + receiveResponseBodyPartsCallback(buffer) + } + + func succeedRequest(_ buffer: CircularBuffer?) { + self.events.append(.succeedRequest(buffer)) + guard let succeedRequestCallback = succeedRequestCallback else { + return self.calledUnimplementedMethod(#function) + } + succeedRequestCallback(buffer) + } + + func fail(_ error: Error) { + self.events.append(.fail(error)) + guard let failCallback = failCallback else { + return self.calledUnimplementedMethod(#function) + } + failCallback(error) + } +} diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift index b5b67c809..e5d9caa8e 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift @@ -12,10 +12,11 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOConcurrencyHelpers import NIOCore +@testable import AsyncHTTPClient + // This is a MockRequestExecutor, that is synchronized on its EventLoop. final class MockRequestExecutor { enum Errors: Error { @@ -24,7 +25,7 @@ final class MockRequestExecutor { case unexpectedByteBuffer } - enum RequestParts: Equatable { + enum RequestParts: Equatable, Sendable { case body(IOData) case endOfStream @@ -47,7 +48,7 @@ final class MockRequestExecutor { } var requestBodyPartsCount: Int { - return self.blockingQueue.count + self.blockingQueue.count } let eventLoop: EventLoop @@ -57,10 +58,15 @@ final class MockRequestExecutor { private let responseBodyDemandLock = ConditionLock(value: false) private let cancellationLock = ConditionLock(value: false) - private var request: HTTPExecutableRequest? - private var _signaledDemandForRequestBody: Bool = false + private struct State: Sendable { + var request: HTTPExecutableRequest? + var _signaledDemandForRequestBody: Bool = false + } + + private let state: NIOLockedValueBox init(pauseRequestBodyPartStreamAfterASingleWrite: Bool = false, eventLoop: EventLoop) { + self.state = NIOLockedValueBox(State()) self.pauseRequestBodyPartStreamAfterASingleWrite = pauseRequestBodyPartStreamAfterASingleWrite self.eventLoop = eventLoop } @@ -76,13 +82,16 @@ final class MockRequestExecutor { } private func runRequest0(_ request: HTTPExecutableRequest) { - precondition(self.request == nil) - self.request = request + self.state.withLockedValue { + precondition($0.request == nil) + $0.request = request + } request.willExecuteRequest(self) request.requestHeadSent() } - func receiveRequestBody(deadline: NIODeadline = .now() + .seconds(5), _ verify: (ByteBuffer) throws -> Void) throws { + func receiveRequestBody(deadline: NIODeadline = .now() + .seconds(5), _ verify: (ByteBuffer) throws -> Void) throws + { enum ReceiveAction { case value(RequestParts) case future(EventLoopFuture) @@ -125,10 +134,16 @@ final class MockRequestExecutor { } private func pauseRequestBodyStream0() { - if self._signaledDemandForRequestBody == true { - self._signaledDemandForRequestBody = false - self.request!.pauseRequestBodyStream() + let request = self.state.withLockedValue { + if $0._signaledDemandForRequestBody == true { + $0._signaledDemandForRequestBody = false + return $0.request + } else { + return nil + } } + + request?.pauseRequestBodyStream() } func resumeRequestBodyStream() { @@ -142,10 +157,16 @@ final class MockRequestExecutor { } private func resumeRequestBodyStream0() { - if self._signaledDemandForRequestBody == false { - self._signaledDemandForRequestBody = true - self.request!.resumeRequestBodyStream() + let request = self.state.withLockedValue { + if $0._signaledDemandForRequestBody == false { + $0._signaledDemandForRequestBody = true + return $0.request + } else { + return nil + } } + + request?.resumeRequestBodyStream() } func resetResponseStreamDemandSignal() { @@ -155,10 +176,11 @@ final class MockRequestExecutor { func receiveResponseDemand(deadline: NIODeadline = .now() + .seconds(5)) throws { let secondsUntilDeath = deadline - NIODeadline.now() - guard self.responseBodyDemandLock.lock( - whenValue: true, - timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) - ) + guard + self.responseBodyDemandLock.lock( + whenValue: true, + timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) + ) else { throw TimeoutError() } @@ -168,10 +190,11 @@ final class MockRequestExecutor { func receiveCancellation(deadline: NIODeadline = .now() + .seconds(5)) throws { let secondsUntilDeath = deadline - NIODeadline.now() - guard self.cancellationLock.lock( - whenValue: true, - timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) - ) + guard + self.cancellationLock.lock( + whenValue: true, + timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) + ) else { throw TimeoutError() } @@ -184,12 +207,14 @@ extension MockRequestExecutor: HTTPRequestExecutor { // this should always be called twice. When we receive the first call, the next call to produce // data is already scheduled. If we call pause here, once, after the second call new subsequent // calls should not be scheduled. - func writeRequestBodyPart(_ part: IOData, request: HTTPExecutableRequest) { + func writeRequestBodyPart(_ part: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { self.writeNextRequestPart(.body(part), request: request) + promise?.succeed(()) } - func finishRequestBodyStream(_ request: HTTPExecutableRequest) { + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { self.writeNextRequestPart(.endOfStream, request: request) + promise?.succeed(()) } private func writeNextRequestPart(_ part: RequestParts, request: HTTPExecutableRequest) { @@ -198,11 +223,13 @@ extension MockRequestExecutor: HTTPRequestExecutor { case none } - let stateChange = { () -> WriteAction in + let stateChange = { @Sendable () -> WriteAction in var pause = false if self.blockingQueue.isEmpty && self.pauseRequestBodyPartStreamAfterASingleWrite && part.isBody { pause = true - self._signaledDemandForRequestBody = false + self.state.withLockedValue { + $0._signaledDemandForRequestBody = false + } } self.blockingQueue.append(.success(part)) @@ -263,8 +290,12 @@ extension MockRequestExecutor { internal func popFirst(deadline: NIODeadline) throws -> Element { let secondsUntilDeath = deadline - NIODeadline.now() - guard self.condition.lock(whenValue: true, - timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000)) else { + guard + self.condition.lock( + whenValue: true, + timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) + ) + else { throw TimeoutError() } let first = self.buffer.removeFirst() @@ -273,3 +304,5 @@ extension MockRequestExecutor { } } } + +extension MockRequestExecutor.BlockingQueue: @unchecked Sendable where Element: Sendable {} diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift index e81f1ed0a..44e820444 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOHTTP1 +@testable import AsyncHTTPClient + /// A mock request queue (not creating any timers) that is used to validate /// request actions returned by the `HTTPConnectionPool.StateMachine`. struct MockRequestQueuer { @@ -82,11 +83,11 @@ struct MockRequestQueuer { return waiter.request } - mutating func timeoutRandomRequest() -> RequestID? { - guard let waiterID = self.waiters.randomElement().map(\.0) else { + mutating func timeoutRandomRequest() -> (RequestID, HTTPSchedulableRequest)? { + guard let waiter = self.waiters.randomElement() else { return nil } - self.waiters.removeValue(forKey: waiterID) - return waiterID + self.waiters.removeValue(forKey: waiter.key) + return (waiter.key, waiter.value.request) } } diff --git a/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift b/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift new file mode 100644 index 000000000..63eaf649d --- /dev/null +++ b/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift @@ -0,0 +1,108 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Network) +@testable import AsyncHTTPClient +import Network +import NIOCore +import NIOConcurrencyHelpers +import NIOEmbedded +import NIOSSL +import NIOTransportServices +import XCTest + +@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) +class NWWaitingHandlerTests: XCTestCase { + final class MockRequester: HTTPConnectionRequester { + private struct State: Sendable { + var waitingForConnectivityCalled = false + var connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID? + var transientError: NWError? + } + + private let state = NIOLockedValueBox(State()) + + var waitingForConnectivityCalled: Bool { + self.state.withLockedValue { $0.waitingForConnectivityCalled } + } + + var connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID? { + self.state.withLockedValue { $0.connectionID } + } + + var transientError: NWError? { + self.state.withLockedValue { + $0.transientError + } + } + + func http1ConnectionCreated(_: AsyncHTTPClient.HTTP1Connection.SendableView) {} + + func http2ConnectionCreated(_: AsyncHTTPClient.HTTP2Connection.SendableView, maximumStreams: Int) {} + + func failedToCreateHTTPConnection(_: AsyncHTTPClient.HTTPConnectionPool.Connection.ID, error: Error) {} + + func waitingForConnectivity(_ connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID, error: Error) { + self.state.withLockedValue { + $0.waitingForConnectivityCalled = true + $0.connectionID = connectionID + $0.transientError = error as? NWError + } + } + } + + func testWaitingHandlerInvokesWaitingForConnectivity() { + let requester = MockRequester() + let connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID = 1 + let waitingEventHandler = NWWaitingHandler(requester: requester, connectionID: connectionID) + let embedded = EmbeddedChannel(handlers: [waitingEventHandler]) + + embedded.pipeline.fireUserInboundEventTriggered( + NIOTSNetworkEvents.WaitingForConnectivity(transientError: .dns(1)) + ) + + XCTAssertTrue( + requester.waitingForConnectivityCalled, + "Expected the handler to invoke .waitingForConnectivity on the requester" + ) + XCTAssertEqual(requester.connectionID, connectionID, "Expected the handler to pass connectionID to requester") + XCTAssertEqual(requester.transientError, NWError.dns(1)) + } + + func testWaitingHandlerDoesNotInvokeWaitingForConnectionOnUnrelatedErrors() { + let requester = MockRequester() + let waitingEventHandler = NWWaitingHandler(requester: requester, connectionID: 1) + let embedded = EmbeddedChannel(handlers: [waitingEventHandler]) + embedded.pipeline.fireUserInboundEventTriggered(NIOTSNetworkEvents.BetterPathAvailable()) + + XCTAssertFalse( + requester.waitingForConnectivityCalled, + "Should not call .waitingForConnectivity on unrelated events" + ) + } + + func testWaitingHandlerPassesTheEventDownTheContext() { + let requester = MockRequester() + let waitingEventHandler = NWWaitingHandler(requester: requester, connectionID: 1) + let tlsEventsHandler = TLSEventsHandler(deadline: nil) + let embedded = EmbeddedChannel(handlers: [waitingEventHandler, tlsEventsHandler]) + + embedded.pipeline.fireErrorCaught(NIOSSLError.handshakeFailed(BoringSSLError.wantConnect)) + XCTAssertThrowsError(try XCTUnwrap(tlsEventsHandler.tlsEstablishedFuture).wait()) { + XCTAssertEqualTypeAndValue($0, NIOSSLError.handshakeFailed(BoringSSLError.wantConnect)) + } + } +} + +#endif diff --git a/Tests/AsyncHTTPClientTests/NoBytesSentOverBodyLimitTests.swift b/Tests/AsyncHTTPClientTests/NoBytesSentOverBodyLimitTests.swift new file mode 100644 index 000000000..026a45d4c --- /dev/null +++ b/Tests/AsyncHTTPClientTests/NoBytesSentOverBodyLimitTests.swift @@ -0,0 +1,81 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class NoBytesSentOverBodyLimitTests: XCTestCaseHTTPClientTestsBaseClass { + func testNoBytesSentOverBodyLimit() throws { + let server = NIOHTTP1TestServer(group: self.serverGroup) + defer { + XCTAssertNoThrow(try server.stop()) + } + + let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n" + + let request = try Request( + url: "http://localhost:\(server.serverPort)", + body: .stream(contentLength: 1) { streamWriter in + streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) + } + ) + + let future = self.defaultClient.execute(request: request) + + // Okay, what happens here needs an explanation: + // + // In the request state machine, we should start the request, which will lead to an + // invocation of `context.write(HTTPRequestHead)`. Since we will receive a streamed request + // body a `context.flush()` will be issued. Further the request stream will be started. + // Since the request stream immediately produces to much data, the request will be failed + // and the connection will be closed. + // + // Even though a flush was issued after the request head, there is no guarantee that the + // request head was written to the network. For this reason we must accept not receiving a + // request and receiving a request head. + + do { + _ = try server.receiveHead() + + // A request head was sent. We expect the request now to fail with a parsing error, + // since the client ended the connection to early (from the server's point of view.) + XCTAssertThrowsError(try server.readInbound()) { + XCTAssertEqual($0 as? HTTPParserError, HTTPParserError.invalidEOFState) + } + } catch { + // TBD: We sadly can't verify the error type, since it is private in `NIOTestUtils`: + // NIOTestUtils.BlockingQueue.TimeoutError + } + + // request must always be failed with this error + XCTAssertThrowsError(try future.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .bodyLengthMismatch) + } + } +} diff --git a/Tests/AsyncHTTPClientTests/RacePoolIdleConnectionsAndGetTests.swift b/Tests/AsyncHTTPClientTests/RacePoolIdleConnectionsAndGetTests.swift new file mode 100644 index 000000000..35a09c421 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/RacePoolIdleConnectionsAndGetTests.swift @@ -0,0 +1,47 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class RacePoolIdleConnectionsAndGetTests: XCTestCaseHTTPClientTestsBaseClass { + func testRacePoolIdleConnectionsAndGet() { + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(connectionPool: .init(idleTimeout: .milliseconds(10))) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + } + for _ in 1...200 { + XCTAssertNoThrow(try localClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()) + Thread.sleep(forTimeInterval: 0.01 + .random(in: -0.01...0.01)) + } + } +} diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift deleted file mode 100644 index 74c68fd1f..000000000 --- a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift +++ /dev/null @@ -1,42 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// RequestBagTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension RequestBagTests { - static var allTests: [(String, (RequestBagTests) -> () throws -> Void)] { - return [ - ("testWriteBackpressureWorks", testWriteBackpressureWorks), - ("testTaskIsFailedIfWritingFails", testTaskIsFailedIfWritingFails), - ("testCancelFailsTaskBeforeRequestIsSent", testCancelFailsTaskBeforeRequestIsSent), - ("testCancelFailsTaskAfterRequestIsSent", testCancelFailsTaskAfterRequestIsSent), - ("testCancelFailsTaskWhenTaskIsQueued", testCancelFailsTaskWhenTaskIsQueued), - ("testFailsTaskWhenTaskIsWaitingForMoreFromServer", testFailsTaskWhenTaskIsWaitingForMoreFromServer), - ("testChannelBecomingWritableDoesntCrashCancelledTask", testChannelBecomingWritableDoesntCrashCancelledTask), - ("testHTTPUploadIsCancelledEvenThoughRequestSucceeds", testHTTPUploadIsCancelledEvenThoughRequestSucceeds), - ("testRaceBetweenConnectionCloseAndDemandMoreData", testRaceBetweenConnectionCloseAndDemandMoreData), - ("testRedirectWith3KBBody", testRedirectWith3KBBody), - ("testRedirectWith4KBBodyAnnouncedInResponseHead", testRedirectWith4KBBodyAnnouncedInResponseHead), - ("testRedirectWith4KBBodyNotAnnouncedInResponseHead", testRedirectWith4KBBodyNotAnnouncedInResponseHead), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index c80f8846b..f1600fceb 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -12,37 +12,54 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient +import Atomics import Logging +import NIOConcurrencyHelpers import NIOCore import NIOEmbedded import NIOHTTP1 +import NIOPosix import XCTest +@testable import AsyncHTTPClient + final class RequestBagTests: XCTestCase { func testWriteBackpressureWorks() { let embeddedEventLoop = EmbeddedEventLoop() defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } let logger = Logger(label: "test") - var writtenBytes = 0 - var writes = 0 + struct TestState { + var writtenBytes: Int = 0 + var writes: Int = 0 + var streamIsAllowedToWrite: Bool = false + } + + let testState = NIOLockedValueBox(TestState()) + let bytesToSent = (3000...10000).randomElement()! let expectedWrites = bytesToSent / 100 + ((bytesToSent % 100 > 0) ? 1 : 0) - var streamIsAllowedToWrite = false let writeDonePromise = embeddedEventLoop.makePromise(of: Void.self) - let requestBody: HTTPClient.Body = .stream(length: bytesToSent) { writer -> EventLoopFuture in - func write(donePromise: EventLoopPromise) { - XCTAssertTrue(streamIsAllowedToWrite) - guard writtenBytes < bytesToSent else { - return donePromise.succeed(()) + let requestBody: HTTPClient.Body = .stream(contentLength: Int64(bytesToSent)) { + writer -> EventLoopFuture in + @Sendable func write(donePromise: EventLoopPromise) { + let futureWrite: EventLoopFuture? = testState.withLockedValue { state in + XCTAssertTrue(state.streamIsAllowedToWrite) + guard state.writtenBytes < bytesToSent else { + donePromise.succeed(()) + return nil + } + let byteCount = min(bytesToSent - state.writtenBytes, 100) + let buffer = ByteBuffer(bytes: [UInt8](repeating: 1, count: byteCount)) + state.writes += 1 + return writer.write(.byteBuffer(buffer)) } - let byteCount = min(bytesToSent - writtenBytes, 100) - let buffer = ByteBuffer(bytes: [UInt8](repeating: 1, count: byteCount)) - writes += 1 - writer.write(.byteBuffer(buffer)).whenSuccess { _ in - writtenBytes += 100 + + futureWrite?.whenSuccess { _ in + testState.withLockedValue { state in + state.writtenBytes += 100 + } write(donePromise: donePromise) } } @@ -53,20 +70,24 @@ final class RequestBagTests: XCTestCase { } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody)) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.task.eventLoop === embeddedEventLoop) @@ -80,46 +101,57 @@ final class RequestBagTests: XCTestCase { executor.runRequest(bag) XCTAssertEqual(delegate.hitDidSendRequestHead, 1) - streamIsAllowedToWrite = true + testState.withLockedValue { $0.streamIsAllowedToWrite = true } bag.resumeRequestBodyStream() - streamIsAllowedToWrite = false + testState.withLockedValue { $0.streamIsAllowedToWrite = false } // after starting the body stream we should have received two writes var receivedBytes = 0 for i in 0.. EventLoopFuture in + let requestBody: HTTPClient.Body = .stream(contentLength: 12) { writer -> EventLoopFuture in writer.write(.byteBuffer(ByteBuffer(bytes: 0...3))).flatMap { _ -> EventLoopFuture in embeddedEventLoop.makeFailedFuture(TestError()) @@ -160,20 +192,24 @@ final class RequestBagTests: XCTestCase { } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody)) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.task.eventLoop === embeddedEventLoop) @@ -206,20 +242,22 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.eventLoop === embeddedEventLoop) let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) - bag.cancel() + bag.fail(HTTPClientError.cancelled) bag.willExecuteRequest(executor) XCTAssertTrue(executor.isCancelled, "The request bag, should call cancel immediately on the executor") @@ -228,6 +266,90 @@ final class RequestBagTests: XCTestCase { } } + func testDeadlineExceededFailsTaskEvenIfRaceBetweenCancelingSchedulerAndRequestStart() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + XCTAssert(bag.eventLoop === embeddedEventLoop) + + let queuer = MockTaskQueuer() + bag.requestWasQueued(queuer) + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + XCTAssertEqual(queuer.hitCancelCount, 0) + bag.deadlineExceeded() + XCTAssertEqual(queuer.hitCancelCount, 1) + + bag.willExecuteRequest(executor) + XCTAssertTrue(executor.isCancelled, "The request bag, should call cancel immediately on the executor") + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqual($0 as? HTTPClientError, .deadlineExceeded) + } + } + + func testCancelHasNoEffectAfterDeadlineExceededFailsTask() { + struct MyError: Error, Equatable {} + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + XCTAssert(bag.eventLoop === embeddedEventLoop) + + let queuer = MockTaskQueuer() + bag.requestWasQueued(queuer) + + XCTAssertEqual(queuer.hitCancelCount, 0) + bag.deadlineExceeded() + XCTAssertEqual(queuer.hitCancelCount, 1) + XCTAssertEqual(delegate.hitDidReceiveError, 0) + bag.fail(MyError()) + XCTAssertEqual(delegate.hitDidReceiveError, 1) + + bag.fail(HTTPClientError.cancelled) + XCTAssertEqual(delegate.hitDidReceiveError, 1) + + XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertEqualTypeAndValue($0, MyError()) + } + } + func testCancelFailsTaskAfterRequestIsSent() { let embeddedEventLoop = EmbeddedEventLoop() defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } @@ -239,15 +361,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.eventLoop === embeddedEventLoop) @@ -261,10 +385,10 @@ final class RequestBagTests: XCTestCase { XCTAssertEqual(delegate.hitDidSendRequestHead, 1) XCTAssertEqual(delegate.hitDidSendRequest, 1) - bag.cancel() + bag.fail(HTTPClientError.cancelled) XCTAssertTrue(executor.isCancelled, "The request bag, should call cancel immediately on the executor") - XCTAssertThrowsError(try bag.task.futureResult.wait()) { + XCTAssertThrowsError(try bag.task.futureResult.timeout(after: .seconds(10)).wait()) { XCTAssertEqual($0 as? HTTPClientError, .cancelled) } } @@ -280,22 +404,24 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let queuer = MockTaskQueuer() bag.requestWasQueued(queuer) XCTAssertEqual(queuer.hitCancelCount, 0) - bag.cancel() + bag.fail(HTTPClientError.cancelled) XCTAssertEqual(queuer.hitCancelCount, 1) XCTAssertThrowsError(try bag.task.futureResult.wait()) { @@ -314,15 +440,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -342,31 +470,35 @@ final class RequestBagTests: XCTestCase { let logger = Logger(label: "test") var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "https://swift.org", - body: .bytes([1, 2, 3, 4, 5]) - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://swift.org", + body: .bytes([1, 2, 3, 4, 5]) + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) executor.runRequest(bag) - // This simulates a race between the user cancelling the task (which invokes `RequestBag.cancel`) and the + // This simulates a race between the user cancelling the task (which invokes `RequestBag.fail(_:)`) and the // call to `resumeRequestBodyStream` (which comes from the `Channel` event loop and so may have to hop. - bag.cancel() + bag.fail(HTTPClientError.cancelled) bag.resumeRequestBodyStream() XCTAssertEqual(executor.isCancelled, true) @@ -375,6 +507,77 @@ final class RequestBagTests: XCTestCase { } } + func testDidReceiveBodyPartFailedPromise() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://swift.org", + method: .POST, + body: .byteBuffer(.init(bytes: [1])) + ) + ) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + struct MyError: Error, Equatable {} + final class Delegate: HTTPClientResponseDelegate { + typealias Response = Void + let didFinishPromise: EventLoopPromise + init(didFinishPromise: EventLoopPromise) { + self.didFinishPromise = didFinishPromise + } + + func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { + task.eventLoop.makeFailedFuture(MyError()) + } + + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + self.didFinishPromise.fail(error) + } + + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws { + XCTFail("\(#function) should not be called") + self.didFinishPromise.succeed(()) + } + } + let delegate = Delegate(didFinishPromise: embeddedEventLoop.makePromise()) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + + executor.runRequest(bag) + + bag.resumeRequestBodyStream() + XCTAssertNoThrow(try executor.receiveRequestBody { XCTAssertEqual($0, ByteBuffer(bytes: [1])) }) + + bag.receiveResponseHead(.init(version: .http1_1, status: .ok)) + + bag.succeedRequest([ByteBuffer([1])]) + + XCTAssertThrowsError(try delegate.didFinishPromise.futureResult.wait()) { error in + XCTAssertEqualTypeAndValue(error, MyError()) + } + XCTAssertThrowsError(try bag.task.futureResult.wait()) { error in + XCTAssertEqualTypeAndValue(error, MyError()) + } + } + func testHTTPUploadIsCancelledEvenThoughRequestSucceeds() { let embeddedEventLoop = EmbeddedEventLoop() defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } @@ -382,42 +585,46 @@ final class RequestBagTests: XCTestCase { var maybeRequest: HTTPClient.Request? let writeSecondPartPromise = embeddedEventLoop.makePromise(of: Void.self) - - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "https://swift.org", - method: .POST, - headers: ["content-length": "12"], - body: .stream(length: 12) { writer -> EventLoopFuture in - var firstWriteSuccess = false - return writer.write(.byteBuffer(.init(bytes: 0...3))).flatMap { _ in - firstWriteSuccess = true - - return writeSecondPartPromise.futureResult - }.flatMap { - return writer.write(.byteBuffer(.init(bytes: 4...7))) - }.always { result in - XCTAssertTrue(firstWriteSuccess) - - guard case .failure(let error) = result else { - return XCTFail("Expected the second write to fail") + let firstWriteSuccess: NIOLockedValueBox = .init(false) + + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://swift.org", + method: .POST, + headers: ["content-length": "12"], + body: .stream(contentLength: 12) { writer -> EventLoopFuture in + writer.write(.byteBuffer(.init(bytes: 0...3))).flatMap { _ in + firstWriteSuccess.withLockedValue { $0 = true } + + return writeSecondPartPromise.futureResult + }.flatMap { + writer.write(.byteBuffer(.init(bytes: 4...7))) + }.always { result in + XCTAssertTrue(firstWriteSuccess.withLockedValue { $0 }) + + guard case .failure(let error) = result else { + return XCTFail("Expected the second write to fail") + } + XCTAssertEqual(error as? HTTPClientError, .requestStreamCancelled) } - XCTAssertEqual(error as? HTTPClientError, .requestStreamCancelled) } - } - )) + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -453,15 +660,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -509,36 +718,49 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? var redirectTriggered = false - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: .init( + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( request: request, - redirectState: RedirectState( - .follow(max: 5, allowCycles: false), - initialURL: request.url.absoluteString - )!, - execute: { request, _ in - XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") - XCTAssertFalse(redirectTriggered) - - let task = HTTPClient.Task(eventLoop: embeddedEventLoop, logger: logger) - task.promise.fail(HTTPClientError.cancelled) - redirectTriggered = true - return task - } - ), - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task( + eventLoop: embeddedEventLoop, + logger: logger + ) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) executor.runRequest(bag) XCTAssertFalse(executor.signalledDemandForResponseBody) - bag.receiveResponseHead(.init(version: .http1_1, status: .permanentRedirect, headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"])) + XCTAssertTrue(delegate.history.isEmpty) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"] + ) + bag.receiveResponseHead(responseHead) + XCTAssertEqual(delegate.history.map(\.request.url), [request.url]) + XCTAssertEqual(delegate.history.map(\.response), [responseHead]) XCTAssertNil(delegate.backpressurePromise) XCTAssertTrue(executor.signalledDemandForResponseBody) executor.resetResponseStreamDemandSignal() @@ -584,36 +806,49 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? var redirectTriggered = false - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: .init( + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( request: request, - redirectState: RedirectState( - .follow(max: 5, allowCycles: false), - initialURL: request.url.absoluteString - )!, - execute: { request, _ in - XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") - XCTAssertFalse(redirectTriggered) - - let task = HTTPClient.Task(eventLoop: embeddedEventLoop, logger: logger) - task.promise.fail(HTTPClientError.cancelled) - redirectTriggered = true - return task - } - ), - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task( + eventLoop: embeddedEventLoop, + logger: logger + ) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) executor.runRequest(bag) XCTAssertFalse(executor.signalledDemandForResponseBody) - bag.receiveResponseHead(.init(version: .http1_1, status: .permanentRedirect, headers: ["content-length": "\(4 * 1024)", "location": "https://swift.org/sswg"])) + XCTAssertTrue(delegate.history.isEmpty) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(4 * 1024)", "location": "https://swift.org/sswg"] + ) + bag.receiveResponseHead(responseHead) + XCTAssertEqual(delegate.history.map(\.request.url), [request.url]) + XCTAssertEqual(delegate.history.map(\.response), [responseHead]) XCTAssertNil(delegate.backpressurePromise) XCTAssertFalse(executor.signalledDemandForResponseBody) XCTAssertTrue(executor.isCancelled) @@ -633,36 +868,49 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? var redirectTriggered = false - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: .init( + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( request: request, - redirectState: RedirectState( - .follow(max: 5, allowCycles: false), - initialURL: request.url.absoluteString - )!, - execute: { request, _ in - XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") - XCTAssertFalse(redirectTriggered) - - let task = HTTPClient.Task(eventLoop: embeddedEventLoop, logger: logger) - task.promise.fail(HTTPClientError.cancelled) - redirectTriggered = true - return task - } - ), - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task( + eventLoop: embeddedEventLoop, + logger: logger + ) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) executor.runRequest(bag) XCTAssertFalse(executor.signalledDemandForResponseBody) - bag.receiveResponseHead(.init(version: .http1_1, status: .permanentRedirect, headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"])) + XCTAssertTrue(delegate.history.isEmpty) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"] + ) + bag.receiveResponseHead(responseHead) + XCTAssertEqual(delegate.history.map(\.request.url), [request.url]) + XCTAssertEqual(delegate.history.map(\.response), [responseHead]) XCTAssertNil(delegate.backpressurePromise) XCTAssertTrue(executor.signalledDemandForResponseBody) executor.resetResponseStreamDemandSignal() @@ -689,85 +937,186 @@ final class RequestBagTests: XCTestCase { XCTAssertTrue(redirectTriggered) } + + func testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise() { + final class LeakDetector: Sendable {} + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group)) + defer { XCTAssertNoThrow(try httpClient.shutdown().wait()) } + + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + + var leakDetector = LeakDetector() + + do { + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/", method: .POST) + ) + guard var request = maybeRequest else { return XCTFail("Expected to have a request here") } + + let writerPromise = group.any().makePromise(of: HTTPClient.Body.StreamWriter.self) + let donePromise = group.any().makePromise(of: Void.self) + request.body = .stream { [leakDetector] writer in + _ = leakDetector + writerPromise.succeed(writer) + return donePromise.futureResult + } + + let resultFuture = httpClient.execute(request: request) + request.body = nil + writerPromise.futureResult.whenSuccess { writer in + writer.write(.byteBuffer(ByteBuffer(string: "hello"))).map { + print("written") + }.cascade(to: donePromise) + } + XCTAssertNoThrow(try donePromise.futureResult.wait()) + print("HTTP sent") + + var result: HTTPClient.Response? + XCTAssertNoThrow(result = try resultFuture.wait()) + + XCTAssertEqual(.ok, result?.status) + let body = result?.body.map { String(buffer: $0) } + XCTAssertNotNil(body) + print("HTTP done") + } + XCTAssertTrue(isKnownUniquelyReferenced(&leakDetector)) + } +} + +extension HTTPClient.Task { + convenience init( + eventLoop: EventLoop, + logger: Logger + ) { + self.init(eventLoop: eventLoop, logger: logger, tracing: .init()) { + preconditionFailure("thread pool not needed in tests") + } + } } -class UploadCountingDelegate: HTTPClientResponseDelegate { +final class UploadCountingDelegate: HTTPClientResponseDelegate { typealias Response = Void let eventLoop: EventLoop - private(set) var hitDidSendRequestHead = 0 - private(set) var hitDidSendRequestPart = 0 - private(set) var hitDidSendRequest = 0 - private(set) var hitDidReceiveResponse = 0 - private(set) var hitDidReceiveBodyPart = 0 - private(set) var hitDidReceiveError = 0 + struct State: Sendable { + var hitDidSendRequestHead = 0 + var hitDidSendRequestPart = 0 + var hitDidSendRequest = 0 + var hitDidReceiveResponse = 0 + var hitDidReceiveBodyPart = 0 + var hitDidReceiveError = 0 + + var history: [(request: HTTPClient.Request, response: HTTPResponseHead)] = [] + var receivedHead: HTTPResponseHead? + var lastBodyPart: ByteBuffer? + var backpressurePromise: EventLoopPromise? + var lastError: Error? + } + + private let state: NIOLoopBoundBox - private(set) var receivedHead: HTTPResponseHead? - private(set) var lastBodyPart: ByteBuffer? - private(set) var backpressurePromise: EventLoopPromise? - private(set) var lastError: Error? + var hitDidSendRequestHead: Int { self.state.value.hitDidSendRequestHead } + var hitDidSendRequestPart: Int { self.state.value.hitDidSendRequestPart } + var hitDidSendRequest: Int { self.state.value.hitDidSendRequest } + var hitDidReceiveResponse: Int { self.state.value.hitDidReceiveResponse } + var hitDidReceiveBodyPart: Int { self.state.value.hitDidReceiveBodyPart } + var hitDidReceiveError: Int { self.state.value.hitDidReceiveError } + + var history: [(request: HTTPClient.Request, response: HTTPResponseHead)] { + self.state.value.history + } + var receivedHead: HTTPResponseHead? { self.state.value.receivedHead } + var lastBodyPart: ByteBuffer? { self.state.value.lastBodyPart } + var backpressurePromise: EventLoopPromise? { self.state.value.backpressurePromise } + var lastError: Error? { self.state.value.lastError } init(eventLoop: EventLoop) { self.eventLoop = eventLoop + self.state = .makeBoxSendingValue(State(), eventLoop: eventLoop) } func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { - self.hitDidSendRequestHead += 1 + self.state.value.hitDidSendRequestHead += 1 } func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { - self.hitDidSendRequestPart += 1 + self.state.value.hitDidSendRequestPart += 1 } func didSendRequest(task: HTTPClient.Task) { - self.hitDidSendRequest += 1 + self.state.value.hitDidSendRequest += 1 + } + + func didVisitURL(task: HTTPClient.Task, _ request: HTTPClient.Request, _ head: HTTPResponseHead) { + self.state.value.history.append((request, head)) } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.receivedHead = head + self.state.value.receivedHead = head return self.createBackpressurePromise() } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - assert(self.backpressurePromise == nil) - self.hitDidReceiveBodyPart += 1 - self.lastBodyPart = buffer + assert(self.state.value.backpressurePromise == nil) + self.state.value.hitDidReceiveBodyPart += 1 + self.state.value.lastBodyPart = buffer return self.createBackpressurePromise() } func didFinishRequest(task: HTTPClient.Task) throws { - self.hitDidReceiveResponse += 1 + self.state.value.hitDidReceiveResponse += 1 } func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.hitDidReceiveError += 1 - self.lastError = error + self.state.value.hitDidReceiveError += 1 + self.state.value.lastError = error } private func createBackpressurePromise() -> EventLoopFuture { - assert(self.backpressurePromise == nil) - self.backpressurePromise = self.eventLoop.makePromise(of: Void.self) - return self.backpressurePromise!.futureResult.always { _ in - self.backpressurePromise = nil + assert(self.state.value.backpressurePromise == nil) + self.state.value.backpressurePromise = self.eventLoop.makePromise(of: Void.self) + return self.state.value.backpressurePromise!.futureResult.always { _ in + self.state.value.backpressurePromise = nil } } } -class MockTaskQueuer: HTTPRequestScheduler { - private(set) var hitCancelCount = 0 +final class MockTaskQueuer: HTTPRequestScheduler { + private let _hitCancelCount = ManagedAtomic(0) + + var hitCancelCount: Int { + self._hitCancelCount.load(ordering: .sequentiallyConsistent) + } - init() {} + let onCancelRequest: (@Sendable (HTTPSchedulableRequest) -> Void)? + + init(onCancelRequest: (@Sendable (HTTPSchedulableRequest) -> Void)? = nil) { + self.onCancelRequest = onCancelRequest + } - func cancelRequest(_: HTTPSchedulableRequest) { - self.hitCancelCount += 1 + func cancelRequest(_ request: HTTPSchedulableRequest) { + self._hitCancelCount.wrappingIncrement(ordering: .sequentiallyConsistent) + self.onCancelRequest?(request) } } extension RequestOptions { - static func forTests(idleReadTimeout: TimeAmount? = nil) -> Self { + static func forTests( + idleReadTimeout: TimeAmount? = nil, + idleWriteTimeout: TimeAmount? = nil, + dnsOverride: [String: String] = [:] + ) -> Self { RequestOptions( - idleReadTimeout: idleReadTimeout + idleReadTimeout: idleReadTimeout, + idleWriteTimeout: idleWriteTimeout, + dnsOverride: dnsOverride ) } } diff --git a/Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift b/Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift deleted file mode 100644 index 3a93d70ec..000000000 --- a/Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift +++ /dev/null @@ -1,52 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// RequestValidationTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension RequestValidationTests { - static var allTests: [(String, (RequestValidationTests) -> () throws -> Void)] { - return [ - ("testContentLengthHeaderIsRemovedFromGETIfNoBody", testContentLengthHeaderIsRemovedFromGETIfNoBody), - ("testContentLengthHeaderIsAddedToPOSTAndPUTWithNoBody", testContentLengthHeaderIsAddedToPOSTAndPUTWithNoBody), - ("testContentLengthHeaderIsChangedIfBodyHasDifferentLength", testContentLengthHeaderIsChangedIfBodyHasDifferentLength), - ("testTRACERequestMustNotHaveBody", testTRACERequestMustNotHaveBody), - ("testGET_HEAD_DELETE_CONNECTRequestCanHaveBody", testGET_HEAD_DELETE_CONNECTRequestCanHaveBody), - ("testInvalidHeaderFieldNames", testInvalidHeaderFieldNames), - ("testValidHeaderFieldNames", testValidHeaderFieldNames), - ("testMetadataDetectConnectionClose", testMetadataDetectConnectionClose), - ("testMetadataDefaultIsConnectionCloseIsFalse", testMetadataDefaultIsConnectionCloseIsFalse), - ("testNoHeadersNoBody", testNoHeadersNoBody), - ("testNoHeadersHasBody", testNoHeadersHasBody), - ("testContentLengthHeaderNoBody", testContentLengthHeaderNoBody), - ("testContentLengthHeaderHasBody", testContentLengthHeaderHasBody), - ("testTransferEncodingHeaderNoBody", testTransferEncodingHeaderNoBody), - ("testTransferEncodingHeaderHasBody", testTransferEncodingHeaderHasBody), - ("testBothHeadersNoBody", testBothHeadersNoBody), - ("testBothHeadersHasBody", testBothHeadersHasBody), - ("testHostHeaderIsSetCorrectlyInCreateRequestHead", testHostHeaderIsSetCorrectlyInCreateRequestHead), - ("testTraceMethodIsNotAllowedToHaveAFixedLengthBody", testTraceMethodIsNotAllowedToHaveAFixedLengthBody), - ("testTraceMethodIsNotAllowedToHaveADynamicLengthBody", testTraceMethodIsNotAllowedToHaveADynamicLengthBody), - ("testTransferEncodingsAreOverwrittenIfBodyLengthIsFixed", testTransferEncodingsAreOverwrittenIfBodyLengthIsFixed), - ("testTransferEncodingsAreOverwrittenIfBodyLengthIsDynamic", testTransferEncodingsAreOverwrittenIfBodyLengthIsDynamic), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/RequestValidationTests.swift b/Tests/AsyncHTTPClientTests/RequestValidationTests.swift index c50d3afd1..ea5a6bd66 100644 --- a/Tests/AsyncHTTPClientTests/RequestValidationTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestValidationTests.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class RequestValidationTests: XCTestCase { func testContentLengthHeaderIsRemovedFromGETIfNoBody() { var headers = HTTPHeaders([("Content-Length", "0")]) @@ -29,13 +30,17 @@ class RequestValidationTests: XCTestCase { func testContentLengthHeaderIsAddedToPOSTAndPUTWithNoBody() { var putHeaders = HTTPHeaders() var putMetadata: RequestFramingMetadata? - XCTAssertNoThrow(putMetadata = try putHeaders.validateAndSetTransportFraming(method: .PUT, bodyLength: .known(0))) + XCTAssertNoThrow( + putMetadata = try putHeaders.validateAndSetTransportFraming(method: .PUT, bodyLength: .known(0)) + ) XCTAssertEqual(putHeaders.first(name: "Content-Length"), "0") XCTAssertEqual(putMetadata?.body, .fixedSize(0)) var postHeaders = HTTPHeaders() var postMetadata: RequestFramingMetadata? - XCTAssertNoThrow(postMetadata = try postHeaders.validateAndSetTransportFraming(method: .POST, bodyLength: .known(0))) + XCTAssertNoThrow( + postMetadata = try postHeaders.validateAndSetTransportFraming(method: .POST, bodyLength: .known(0)) + ) XCTAssertEqual(postHeaders.first(name: "Content-Length"), "0") XCTAssertEqual(postMetadata?.body, .fixedSize(0)) } @@ -90,7 +95,7 @@ class RequestValidationTests: XCTestCase { func testMetadataDetectConnectionClose() { var headers = HTTPHeaders([ - ("Connection", "close"), + ("Connection", "close") ]) var metadata: RequestFramingMetadata? XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: .GET, bodyLength: .known(0))) @@ -114,7 +119,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT, .TRACE] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -123,7 +130,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertEqual(headers["content-length"].first, "0") XCTAssertFalse(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -139,7 +148,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -149,7 +160,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown)) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .stream) @@ -159,7 +172,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -169,7 +184,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown)) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .stream) @@ -184,7 +201,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT, .TRACE] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -193,7 +212,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertEqual(headers["content-length"].first, "0") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -208,7 +229,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -217,7 +240,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -232,7 +257,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT, .TRACE] { var headers: HTTPHeaders = .init([("Transfer-Encoding", "chunked")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertFalse(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -241,7 +268,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init([("Transfer-Encoding", "chunked")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertEqual(headers["content-length"].first, "0") XCTAssertFalse(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -337,21 +366,27 @@ class RequestValidationTests: XCTestCase { func testTransferEncodingsAreOverwrittenIfBodyLengthIsFixed() { var headers: HTTPHeaders = [ - "Transfer-Encoding": "gzip, chunked", + "Transfer-Encoding": "gzip, chunked" ] XCTAssertNoThrow(try headers.validateAndSetTransportFraming(method: .POST, bodyLength: .known(1))) - XCTAssertEqual(headers, [ - "Content-Length": "1", - ]) + XCTAssertEqual( + headers, + [ + "Content-Length": "1" + ] + ) } func testTransferEncodingsAreOverwrittenIfBodyLengthIsDynamic() { var headers: HTTPHeaders = [ - "Transfer-Encoding": "gzip, chunked", + "Transfer-Encoding": "gzip, chunked" ] XCTAssertNoThrow(try headers.validateAndSetTransportFraming(method: .POST, bodyLength: .unknown)) - XCTAssertEqual(headers, [ - "Transfer-Encoding": "chunked", - ]) + XCTAssertEqual( + headers, + [ + "Transfer-Encoding": "chunked" + ] + ) } } diff --git a/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem b/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem new file mode 100644 index 000000000..f16590cde --- /dev/null +++ b/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBwTCCAUigAwIBAgIUX7f9BABxGdAqG5EvLpQScFt9lOkwCgYIKoZIzj0EAwMw +KjEUMBIGA1UECgwLU2VsZiBTaWduZWQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0y +NTA0MDExNDMwMTFaFw0yNjA0MDExNDMwMTFaMCoxFDASBgNVBAoMC1NlbGYgU2ln +bmVkMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQAIgNiAAQW +szfO5HCWIWgKUqyXUU0pFpYgaq01RRL69XZz1CkV6XTrxMfIvvwez2886EQDL8QX +i5NpKg3qvPgWuDjVHaj4WEJe5XMNqcujxcTufBlmaQ6o4vtoK7CIHDIDldF/HRij +LzAtMBYGA1UdEQQPMA2CC2V4YW1wbGUuY29tMBMGA1UdJQQMMAoGCCsGAQUFBwMB +MAoGCCqGSM49BAMDA2cAMGQCMBJ8Dxg0qX2bEZ3r6dI3UCGAUYxJDVk+XhiIY1Fm +5FJeQqhaVayCRPrPXXGZUJGY/wIwXej70FwkxHKLq+XxfHTC5CzmoOK469C9Rk9Y +ucddXM83ebFxVNgRCWetH9tDdXJ9 +-----END CERTIFICATE----- \ No newline at end of file diff --git a/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem b/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem new file mode 100644 index 000000000..3ad9ce79e --- /dev/null +++ b/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem @@ -0,0 +1,6 @@ +-----BEGIN PRIVATE KEY----- +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDD9v51MTOcgFIbiHbok +U+QOubosGF1u1q+D3fEUb1U2cgjCofKmPHekXTz0xu9MJi2hZANiAAQWszfO5HCW +IWgKUqyXUU0pFpYgaq01RRL69XZz1CkV6XTrxMfIvvwez2886EQDL8QXi5NpKg3q +vPgWuDjVHaj4WEJe5XMNqcujxcTufBlmaQ6o4vtoK7CIHDIDldF/HRg= +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/Tests/AsyncHTTPClientTests/Resources/self_signed_cert.pem b/Tests/AsyncHTTPClientTests/Resources/self_signed_cert.pem new file mode 100644 index 000000000..20b46f355 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/Resources/self_signed_cert.pem @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEpjCCAo4CCQCeTmfiTQcJrzANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls +b2NhbGhvc3QwIBcNMjIwNjE0MTI1NDQ4WhgPMjI5NjAzMjgxMjU0NDhaMBQxEjAQ +BgNVBAMMCWxvY2FsaG9zdDCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIB +AK16gPDwP/Xbaf36x5BNd6yHDxCPIIJP4JLfMEuozwLE0YRqwmZOuklb4jUbAXf7 +u9u24ANrC4XS6VVWkfPdugokAUkaKPpwkV4GOiMCXeSDjDiLt1dYxlbp+MLV78a5 +oUDbCAqfFKebIgv1oiK+L6/p818eAHSBWEXXMhTeBDEQAIpJLTG88iVu6r3fMJeH +FbMWuPmAajmx2AEGmwD1x6+NHZLJv1zaufa7j0sHADraagXnfKn6rkLn1is6QFu4 +v7xaNlEwsRCYbh0nrtCtEdJIqnEHc0GCu/gnw5GE3CuRG3FYBTZStIF7d9h+XZQB +ky/YEWSGw9DXFBbebOZugopvl91qaZLqo6Wg0J8qCodgFtJHOSVMq/SAOBmKyw+b +7FYZbj4tQKpuuhwCN+gwEveTy+BK+zGY/sVzPwR8PNjpCgT/HiOBM7dNt4+2r9pY +Ld/mcMvakgRzM4Iqqntem9ltuckZev0TRjdrIylVWsAlNYVXm4ncMLkbzxFkv5Gb +AlhAuTwxyFkIo0M7+GS4lXCZ2bX2umJ0DTl3/NGJserFdkOhvHZSHHC9BzDBysmc +SejX/cGOFQ8O3sFeJdVMGlO64dU482O0FbBcLHmTLXWR4t8dlhrzJuXZ4X6WtHqY +83RwyD1gacYRZnT0eL+Z7XGrO1/qypji1RNaFIaGUt7DAgMBAAEwDQYJKoZIhvcN +AQELBQADggIBAIigOuEVirgqXoUMStTwYObs/DcNIPEugn9gAq9Lt1cr6fm7CvhG +AupxoJTbKLHQX6FegvFSA+4Kt3KYXX9Qi9SJF3Vr4zOhV0q203d4Aui6Lamo5Yye +nhbzzXuDSIyxpaWPFRC2RqCA6+hV8/Ar9Bx0TCI4NQxWxQEPerwqzqWCuTbViccw +WzlwRD2AHibaQaCbpzXg9lOX0fRJHoSM3exYQd91pDoSoL3f/EV3I/czssq+10M8 +F4GhE4bQjaKD7jL5U59dlvfy73nLAzzxzsxsFuYTAgzZwDg586sdbrqqFjzjoZ9A +dF8NuVYkHyFDQkpe66e1isNZi7eFdSjeVmj8llp4b6in59ik7ZS7arzGOxhZZzmv +Jf3nfE4hJzMS/4GJsKMdtcI+6K+hMi6Yt9OoPh82SQ2q8gK4QSWWrwAKuQ4F4UeO +pgiWBryKrkOXlGARBbsR/ZDhlqyAskeGuhIpEY5NLCByFfQ5KlcrX+n4TVLRZMvb +/7PZqboGgU+CUVawm/suPAs8jOlFQOzrxWQPRfWVvFII62ABgozS8N/xZ/WbgTVj +kOtWj85NpaBSCUliIY/7z1FkjpMZO8Kds45WQzAq4YChDLZGbgV0MkyXqO/LEYFJ +zqGOP1yGxVcKxu6t8Xh0hL6JPFmKWiMEWVrd1wut6NAIu6WNftmWZX6J +-----END CERTIFICATE----- diff --git a/Tests/AsyncHTTPClientTests/Resources/self_signed_key.pem b/Tests/AsyncHTTPClientTests/Resources/self_signed_key.pem new file mode 100644 index 000000000..8811c2d81 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/Resources/self_signed_key.pem @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCteoDw8D/122n9 ++seQTXeshw8QjyCCT+CS3zBLqM8CxNGEasJmTrpJW+I1GwF3+7vbtuADawuF0ulV +VpHz3boKJAFJGij6cJFeBjojAl3kg4w4i7dXWMZW6fjC1e/GuaFA2wgKnxSnmyIL +9aIivi+v6fNfHgB0gVhF1zIU3gQxEACKSS0xvPIlbuq93zCXhxWzFrj5gGo5sdgB +BpsA9cevjR2Syb9c2rn2u49LBwA62moF53yp+q5C59YrOkBbuL+8WjZRMLEQmG4d +J67QrRHSSKpxB3NBgrv4J8ORhNwrkRtxWAU2UrSBe3fYfl2UAZMv2BFkhsPQ1xQW +3mzmboKKb5fdammS6qOloNCfKgqHYBbSRzklTKv0gDgZissPm+xWGW4+LUCqbroc +AjfoMBL3k8vgSvsxmP7Fcz8EfDzY6QoE/x4jgTO3TbePtq/aWC3f5nDL2pIEczOC +Kqp7XpvZbbnJGXr9E0Y3ayMpVVrAJTWFV5uJ3DC5G88RZL+RmwJYQLk8MchZCKND +O/hkuJVwmdm19rpidA05d/zRibHqxXZDobx2UhxwvQcwwcrJnEno1/3BjhUPDt7B +XiXVTBpTuuHVOPNjtBWwXCx5ky11keLfHZYa8ybl2eF+lrR6mPN0cMg9YGnGEWZ0 +9Hi/me1xqztf6sqY4tUTWhSGhlLewwIDAQABAoICAApRcP3jrEo5RLKgieIhWW7f +kZvQh4R4r8jMkZjOb5Gglz2jA/EF2bqnRmsWMh4q0N+envBVG5hYFRzIS2IP3BLi +VVk9vxY2P88x259dcqw2zs5GMR923kUpIWylQN+3BspOvMm08IuPhJTlhUE/wqJZ +7enIZQqI7vEofYgUNHeelgmjlJaSwGxNjpTAg6lflYDTZykf5DGOTGSzOeDyvW/J +muqyKTmioND2Eu3JetAFUa0MObP6fwbntytXCaDq+ix/yR9HICD2kAYX6CPtR1QU +kl6qrMZGultmMhGjr1zAArvZGmZCwQ26hERSL8qv1UtRNKegBGGViVJa5GtIQ2dT +UmTWmWu/5gyxKvvjuqYl8Dub2/ZT0iGAsA6hGyUr+vpgjcNEZqsYhiEiQPi0g1sM +XyszytqG1F7JzXYgVzcdFA9L+eLD+i4nKD18TYTYHFGRmxwQ+HzHnetgDQ2gqRbB +XwT4lp643oNLMGyL+T0cQ7i1Hpq7Ko0S2FeXzzFe9B33uXbDvc0usier5qx2tgxc +zfgSqJjahfo4LCxhxvBWOup3U/sXNgyMCctr1qjpwGwLek+H1keOyv7FO9O6OgI1 +v5ZPFsJV7mK1fDLM/8QLDpUcUNnhPUfzsBdxKrjLfnZ8MPNczgv1GPzb4jsLvewf +g6ps8oBwnZDQVa6dMuyRAoIBAQDnTKRUsTMmFo01o0k90C8SwwE2x7Wry8r6vIIf +PMni3ZAS+zWFnu1zg82+83QpdvskntWM2iXS7nimmkXClCCFMDU/hYA9EsZtGIv6 ++xA6gYF0Xd3Qf9QrvhixOxHj3ixNyCeee3/9XUYln3ZfEx8cgCwHjPSIm3rOKI2M +PFnuG9xJ513sy6YCDrCdtb661E6bmsaMcIhu6S7At0njwnoL9aB617TSds5tFEr8 +74EW3D9epN01uUQ9MgZSXbzdQ82IswLps4a/k4wfDFp4qKpx7zOsoTSjA9il3fgW +QLhBXxnzTYYTvwxIgaW//fyqEL3p6t9zuYcjbORcrj7v8xIvAoIBAQDAASGjsSCA +hn03DXrI/atoXEC0htVwPwp4HTI0Z1/rOS0IrFBcX3CWx90Dr/clePHQGPk1yOO7 +oM83zumwggIOymtDhlTcCa77yN9x9AZMW3qPMF+mvAouUzItnlMrOjvfEnIWziWC +UsylBiV4/I6tf0zpH8zFYPNXq98fpv+UXyJDTW+YGBc2b2BwZZA6RdtFalqvunM7 +M8FIH8vSYEMR0YC47L2ceBJY/U9EQpsc6vuS7+CoXOH/WRb5v1z+a5O9sHWp8Rdc +Oh67B6v2feUT9TwhGUVF0L+ktW389e3N+VzPvbvICvRsOvo6+bceCJTszhNno00s +87bPyelaHXutAoIBAFtJ6onqri9YMz96RMv6wLl88Zu3UsKNWn1/rTO7AEtj+xsi +vssQINO4r5mv6Kb86L5ZWhuPdeI8cK4AsYvMftFSZ5G8lRKFuH8Scx0Jviv5NSjC +a2uBKDJjgsdgcv0mkQHZ/5kTUT6kc60htMxtdZgAFmCch17rTprTcppor23E3Trl +8DInZkvllFuKgc6nQKc1fSustoxfyC4TqTwVY6oYtdAGFr4CWhK/MaGGvcJSB0jJ +dO1hQ8eLWOdlS8dgnVxYmsu2KXavO1x9ua9pkmwJZrG5pla4i+dbJjFSNebHLCzU +6hgdDTIIyWxvSCuvE+Wg57R7AxU+Qxs5Qmnd280CggEAex4+m+BwnvmeQTb7jPZc +e0bsltX+90L1S6AtGT1QXF0Fa5JS1Wi9oXH3Xu3u5LBxHqdk5gAzR5UOSxL69pvn +BeT2cw4oTBBJjFp6LW/0ufHO3RJ/w0LApIPkoSvs2MM2sQv67HSzyKWfZBJU5QfN +1aLTholFnStV3tnu8TT8nf+C0PVOoZCREe7JQElf+n3g5NoV3KkKSuQdBEqfP/9K +Apr8l5f23eaAnV+Q/IxZOmnTd50pycwFft95xBvZXatNyUzlpltaR2FdY0DAHAcO +ZYXTUMYLjYEV4mAUbyijnHhR80QOrW+Y2+3VlwuZSEDofhCGkOY+Dp0YlJU8dPSC +4QKCAQEA3qlwsjJT8Vv+Sx2sZySDNtey/00yjxe4TckdH8dWTC22G38/ppbazlp/ +YVqFUgteoo5FI60HKa0+jLKQZpCIH6JgpL3yq81Obl0wRkrSNePRTL1Ikiff8P2j +bowFpbIdJLgDco1opJpDgTOz2mB7HlHu6RyoKjiVrNA/EOoks1Uljxdth6h/5ctr +rLn8dnz2sTtwxcUsOpyFcFQ2qaWJvSg+bF7JPPzMrpQfCR1qVWa43Kl8KlcWSKaq +ITpglIBY+h3F2GygAAcnpfkXde381Iw89y7TFd2LxWQR98zhnbJWF2JmuuPDtVRv ++HYZkcyQcpDwfC+2NOWOU7NQj+IDIA== +-----END PRIVATE KEY----- diff --git a/Tests/AsyncHTTPClientTests/ResponseDelayGetTests.swift b/Tests/AsyncHTTPClientTests/ResponseDelayGetTests.swift new file mode 100644 index 000000000..5fd1d6720 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/ResponseDelayGetTests.swift @@ -0,0 +1,46 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class ResponseDelayGetTests: XCTestCaseHTTPClientTestsBaseClass { + func testResponseDelayGet() throws { + let req = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .GET, + headers: ["X-internal-delay": "2000"], + body: nil + ) + let start = NIODeadline.now() + let response = try self.defaultClient.execute(request: req).wait() + XCTAssertGreaterThanOrEqual(.now() - start, .milliseconds(1_900)) + XCTAssertEqual(response.status, .ok) + } +} diff --git a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests+XCTest.swift deleted file mode 100644 index 0338adf3c..000000000 --- a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests+XCTest.swift +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// SOCKSEventsHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension SOCKSEventsHandlerTests { - static var allTests: [(String, (SOCKSEventsHandlerTests) -> () throws -> Void)] { - return [ - ("testHandlerHappyPath", testHandlerHappyPath), - ("testHandlerFailsFutureWhenRemovedWithoutEvent", testHandlerFailsFutureWhenRemovedWithoutEvent), - ("testHandlerFailsFutureWhenHandshakeFails", testHandlerFailsFutureWhenHandshakeFails), - ("testHandlerClosesConnectionIfHandshakeTimesout", testHandlerClosesConnectionIfHandshakeTimesout), - ("testHandlerWorksIfDeadlineIsInPast", testHandlerWorksIfDeadlineIsInPast), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift b/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift index 066a631a5..2352c6c1c 100644 --- a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOSOCKS import XCTest +@testable import AsyncHTTPClient + class SOCKSEventsHandlerTests: XCTestCase { func testHandlerHappyPath() { let socksEventsHandler = SOCKSEventsHandler(deadline: .now() + .seconds(10)) @@ -37,7 +38,7 @@ class SOCKSEventsHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [socksEventsHandler]) XCTAssertNotNil(socksEventsHandler.socksEstablishedFuture) - XCTAssertNoThrow(try embedded.pipeline.removeHandler(socksEventsHandler).wait()) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.removeHandler(socksEventsHandler).wait()) XCTAssertThrowsError(try XCTUnwrap(socksEventsHandler.socksEstablishedFuture).wait()) } diff --git a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift index d7c97e6fe..50d26b278 100644 --- a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift @@ -40,25 +40,40 @@ class MockSOCKSServer { self.channel.localAddress!.port! } - init(expectedURL: String, expectedResponse: String, misbehave: Bool = false, file: String = #file, line: UInt = #line) throws { + init( + expectedURL: String, + expectedResponse: String, + misbehave: Bool = false, + file: String = #filePath, + line: UInt = #line + ) throws { let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) let bootstrap: ServerBootstrap if misbehave { bootstrap = ServerBootstrap(group: elg) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelInitializer { channel in - channel.pipeline.addHandler(TestSOCKSBadServerHandler()) + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(TestSOCKSBadServerHandler()) + } } } else { bootstrap = ServerBootstrap(group: elg) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelInitializer { channel in - let handshakeHandler = SOCKSServerHandshakeHandler() - return channel.pipeline.addHandlers([ - handshakeHandler, - SOCKSTestHandler(handshakeHandler: handshakeHandler), - TestHTTPServer(expectedURL: expectedURL, expectedResponse: expectedResponse, file: file, line: line), - ]) + channel.eventLoop.makeCompletedFuture { + let handshakeHandler = SOCKSServerHandshakeHandler() + try channel.pipeline.syncOperations.addHandlers([ + handshakeHandler, + SOCKSTestHandler(handshakeHandler: handshakeHandler), + TestHTTPServer( + expectedURL: expectedURL, + expectedResponse: expectedResponse, + file: file, + line: line + ), + ]) + } } } self.channel = try bootstrap.bind(host: "localhost", port: 0).wait() @@ -86,19 +101,34 @@ class SOCKSTestHandler: ChannelInboundHandler, RemovableChannelHandler { let message = self.unwrapInboundIn(data) switch message { case .greeting: - context.writeAndFlush(.init( - ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired))), promise: nil) + context.writeAndFlush( + .init( + ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired)) + ), + promise: nil + ) case .authenticationData: context.fireErrorCaught(MockSOCKSError(description: "Received authentication data but didn't receive any.")) case .request(let request): - context.writeAndFlush(.init( - ServerMessage.response(.init(reply: .succeeded, boundAddress: request.addressType))), promise: nil) - context.channel.pipeline.addHandlers([ - ByteToMessageHandler(HTTPRequestDecoder()), - HTTPResponseEncoder(), - ], position: .after(self)).whenSuccess { - context.channel.pipeline.removeHandler(self, promise: nil) - context.channel.pipeline.removeHandler(self.handshakeHandler, promise: nil) + context.writeAndFlush( + .init( + ServerMessage.response(.init(reply: .succeeded, boundAddress: request.addressType)) + ), + promise: nil + ) + + do { + try context.channel.pipeline.syncOperations.addHandlers( + [ + ByteToMessageHandler(HTTPRequestDecoder()), + HTTPResponseEncoder(), + ], + position: .after(self) + ) + context.channel.pipeline.syncOperations.removeHandler(self, promise: nil) + context.channel.pipeline.syncOperations.removeHandler(self.handshakeHandler, promise: nil) + } catch { + context.fireErrorCaught(error) } } } @@ -134,7 +164,12 @@ class TestHTTPServer: ChannelInboundHandler { break case .end: context.write(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), promise: nil) - context.write(self.wrapOutboundOut(.body(.byteBuffer(context.channel.allocator.buffer(string: self.expectedResponse)))), promise: nil) + context.write( + self.wrapOutboundOut( + .body(.byteBuffer(context.channel.allocator.buffer(string: self.expectedResponse))) + ), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } } diff --git a/Tests/AsyncHTTPClientTests/SSLContextCacheTests+XCTest.swift b/Tests/AsyncHTTPClientTests/SSLContextCacheTests+XCTest.swift deleted file mode 100644 index d98f5a853..000000000 --- a/Tests/AsyncHTTPClientTests/SSLContextCacheTests+XCTest.swift +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// SSLContextCacheTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension SSLContextCacheTests { - static var allTests: [(String, (SSLContextCacheTests) -> () throws -> Void)] { - return [ - ("testRequestingSSLContextWorks", testRequestingSSLContextWorks), - ("testCacheWorks", testCacheWorks), - ("testCacheDoesNotReturnWrongEntry", testCacheDoesNotReturnWrongEntry), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift b/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift index 438c643d7..c7588cc7d 100644 --- a/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift +++ b/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOPosix import NIOSSL import XCTest +@testable import AsyncHTTPClient + final class SSLContextCacheTests: XCTestCase { func testRequestingSSLContextWorks() { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) @@ -27,9 +28,13 @@ final class SSLContextCacheTests: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - XCTAssertNoThrow(try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) } func testCacheWorks() { @@ -43,12 +48,20 @@ final class SSLContextCacheTests: XCTestCase { var firstContext: NIOSSLContext? var secondContext: NIOSSLContext? - XCTAssertNoThrow(firstContext = try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) - XCTAssertNoThrow(secondContext = try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + firstContext = try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) + XCTAssertNoThrow( + secondContext = try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) XCTAssertNotNil(firstContext) XCTAssertNotNil(secondContext) XCTAssert(firstContext === secondContext) @@ -65,16 +78,24 @@ final class SSLContextCacheTests: XCTestCase { var firstContext: NIOSSLContext? var secondContext: NIOSSLContext? - XCTAssertNoThrow(firstContext = try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + firstContext = try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) // Second one has a _different_ TLSConfiguration. var testTLSConfig = TLSConfiguration.makeClientConfiguration() testTLSConfig.certificateVerification = .none - XCTAssertNoThrow(secondContext = try cache.sslContext(tlsConfiguration: testTLSConfig, - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + secondContext = try cache.sslContext( + tlsConfiguration: testTLSConfig, + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) XCTAssertNotNil(firstContext) XCTAssertNotNil(secondContext) XCTAssert(firstContext !== secondContext) diff --git a/Tests/AsyncHTTPClientTests/StressGetHttpsTests.swift b/Tests/AsyncHTTPClientTests/StressGetHttpsTests.swift new file mode 100644 index 000000000..587e6c64c --- /dev/null +++ b/Tests/AsyncHTTPClientTests/StressGetHttpsTests.swift @@ -0,0 +1,58 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import Atomics +import Logging +import NIOConcurrencyHelpers +import NIOCore +import NIOFoundationCompat +import NIOHTTP1 +import NIOHTTPCompression +import NIOPosix +import NIOSSL +import NIOTestUtils +import NIOTransportServices +import XCTest + +#if canImport(Network) +import Network +#endif + +final class StressGetHttpsTests: XCTestCaseHTTPClientTestsBaseClass { + func testStressGetHttps() throws { + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + + let eventLoop = localClient.eventLoopGroup.next() + let requestCount = 200 + var futureResults = [EventLoopFuture]() + for _ in 1...requestCount { + let req = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + headers: ["X-internal-delay": "100"] + ) + futureResults.append(localClient.execute(request: req)) + } + XCTAssertNoThrow(try EventLoopFuture.andAllSucceed(futureResults, on: eventLoop).wait()) + } +} diff --git a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests+XCTest.swift b/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests+XCTest.swift deleted file mode 100644 index 062132f4e..000000000 --- a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests+XCTest.swift +++ /dev/null @@ -1,34 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// TLSEventsHandlerTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension TLSEventsHandlerTests { - static var allTests: [(String, (TLSEventsHandlerTests) -> () throws -> Void)] { - return [ - ("testHandlerHappyPath", testHandlerHappyPath), - ("testHandlerFailsFutureWhenRemovedWithoutEvent", testHandlerFailsFutureWhenRemovedWithoutEvent), - ("testHandlerFailsFutureWhenHandshakeFails", testHandlerFailsFutureWhenHandshakeFails), - ("testHandlerIgnoresShutdownCompletedEvent", testHandlerIgnoresShutdownCompletedEvent), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift b/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift index c119c7e50..988ba6e3f 100644 --- a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOSSL import NIOTLS import XCTest +@testable import AsyncHTTPClient + class TLSEventsHandlerTests: XCTestCase { func testHandlerHappyPath() { let tlsEventsHandler = TLSEventsHandler(deadline: nil) @@ -38,7 +39,7 @@ class TLSEventsHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [tlsEventsHandler]) XCTAssertNotNil(tlsEventsHandler.tlsEstablishedFuture) - XCTAssertNoThrow(try embedded.pipeline.removeHandler(tlsEventsHandler).wait()) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.removeHandler(tlsEventsHandler).wait()) XCTAssertThrowsError(try XCTUnwrap(tlsEventsHandler.tlsEstablishedFuture).wait()) } diff --git a/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests+XCTest.swift deleted file mode 100644 index a46c7dfc0..000000000 --- a/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests+XCTest.swift +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// Transaction+StateMachineTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension Transaction_StateMachineTests { - static var allTests: [(String, (Transaction_StateMachineTests) -> () throws -> Void)] { - return [ - ("testRequestWasQueuedAfterWillExecuteRequestWasCalled", testRequestWasQueuedAfterWillExecuteRequestWasCalled), - ("testRequestBodyStreamWasPaused", testRequestBodyStreamWasPaused), - ("testQueuedRequestGetsRemovedWhenDeadlineExceeded", testQueuedRequestGetsRemovedWhenDeadlineExceeded), - ("testScheduledRequestGetsRemovedWhenDeadlineExceeded", testScheduledRequestGetsRemovedWhenDeadlineExceeded), - ("testRequestWithHeadReceivedGetNotCancelledWhenDeadlineExceeded", testRequestWithHeadReceivedGetNotCancelledWhenDeadlineExceeded), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift b/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift index ff1972330..a631e9a93 100644 --- a/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift @@ -12,16 +12,21 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + +struct NoOpAsyncSequenceProducerDelegate: NIOAsyncSequenceProducerDelegate { + func produceMore() {} + func didTerminate() {} +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class Transaction_StateMachineTests: XCTestCase { func testRequestWasQueuedAfterWillExecuteRequestWasCalled() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let eventLoop = EmbeddedEventLoop() XCTAsyncTest { func workaround(_ continuation: CheckedContinuation) { @@ -33,7 +38,10 @@ final class Transaction_StateMachineTests: XCTestCase { state.requestWasQueued(queuer) let failAction = state.fail(HTTPClientError.cancelled) - guard case .failResponseHead(_, let error, let scheduler, let rexecutor, let bodyStreamContinuation) = failAction else { + guard + case .failResponseHead(_, let error, let scheduler, let rexecutor, let bodyStreamContinuation) = + failAction + else { return XCTFail("Unexpected fail action: \(failAction)") } XCTAssertEqual(error as? HTTPClientError, .cancelled) @@ -46,12 +54,9 @@ final class Transaction_StateMachineTests: XCTestCase { await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) } - #endif } func testRequestBodyStreamWasPaused() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let eventLoop = EmbeddedEventLoop() XCTAsyncTest { func workaround(_ continuation: CheckedContinuation) { @@ -69,12 +74,10 @@ final class Transaction_StateMachineTests: XCTestCase { await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) } - #endif } func testQueuedRequestGetsRemovedWhenDeadlineExceeded() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + struct MyError: Error, Equatable {} XCTAsyncTest { func workaround(_ continuation: CheckedContinuation) { var state = Transaction.StateMachine(continuation) @@ -82,23 +85,69 @@ final class Transaction_StateMachineTests: XCTestCase { state.requestWasQueued(queuer) - let failAction = state.deadlineExceeded() - guard case .cancel(let continuation, let scheduler, nil, nil) = failAction else { + let deadlineExceededAction = state.deadlineExceeded() + guard case .cancelSchedulerOnly(let scheduler) = deadlineExceededAction else { + return XCTFail("Unexpected fail action: \(deadlineExceededAction)") + } + XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) + + let failAction = state.fail(MyError()) + guard + case .failResponseHead(let continuation, let error, nil, nil, bodyStreamContinuation: nil) = + failAction + else { return XCTFail("Unexpected fail action: \(failAction)") } XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) - continuation.resume(throwing: HTTPClientError.deadlineExceeded) + continuation.resume(throwing: error) } - await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) + await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) { + XCTAssertEqualTypeAndValue($0, MyError()) + } + } + } + + func testDeadlineExceededAndFullyFailedRequestCanBeCanceledWithNoEffect() { + struct MyError: Error, Equatable {} + XCTAsyncTest { + func workaround(_ continuation: CheckedContinuation) { + var state = Transaction.StateMachine(continuation) + let queuer = MockTaskQueuer() + + state.requestWasQueued(queuer) + + let deadlineExceededAction = state.deadlineExceeded() + guard case .cancelSchedulerOnly(let scheduler) = deadlineExceededAction else { + return XCTFail("Unexpected fail action: \(deadlineExceededAction)") + } + XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) + + let failAction = state.fail(MyError()) + guard + case .failResponseHead(let continuation, let error, nil, nil, bodyStreamContinuation: nil) = + failAction + else { + return XCTFail("Unexpected fail action: \(failAction)") + } + XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) + + let secondFailAction = state.fail(HTTPClientError.cancelled) + guard case .none = secondFailAction else { + return XCTFail("Unexpected fail action: \(secondFailAction)") + } + + continuation.resume(throwing: error) + } + + await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) { + XCTAssertEqualTypeAndValue($0, MyError()) + } } - #endif } func testScheduledRequestGetsRemovedWhenDeadlineExceeded() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let eventLoop = EmbeddedEventLoop() XCTAsyncTest { func workaround(_ continuation: CheckedContinuation) { @@ -120,12 +169,40 @@ final class Transaction_StateMachineTests: XCTestCase { await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) } - #endif + } + + func testDeadlineExceededRaceWithRequestWillExecute() { + let eventLoop = EmbeddedEventLoop() + XCTAsyncTest { + func workaround(_ continuation: CheckedContinuation) { + var state = Transaction.StateMachine(continuation) + let expectedExecutor = MockRequestExecutor(eventLoop: eventLoop) + let queuer = MockTaskQueuer() + + state.requestWasQueued(queuer) + + let deadlineExceededAction = state.deadlineExceeded() + guard case .cancelSchedulerOnly(let scheduler) = deadlineExceededAction else { + return XCTFail("Unexpected fail action: \(deadlineExceededAction)") + } + XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) + + let failAction = state.willExecuteRequest(expectedExecutor) + guard case .cancelAndFail(let returnedExecutor, let continuation, with: let error) = failAction else { + return XCTFail("Unexpected fail action: \(failAction)") + } + XCTAssertIdentical(returnedExecutor as? MockRequestExecutor, expectedExecutor) + + continuation.resume(throwing: error) + } + + await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) { + XCTAssertEqualTypeAndValue($0, HTTPClientError.deadlineExceeded) + } + } } func testRequestWithHeadReceivedGetNotCancelledWhenDeadlineExceeded() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let eventLoop = EmbeddedEventLoop() XCTAsyncTest { func workaround(_ continuation: CheckedContinuation) { @@ -136,8 +213,11 @@ final class Transaction_StateMachineTests: XCTestCase { XCTAssertEqual(state.willExecuteRequest(executor), .none) state.requestWasQueued(queuer) let head = HTTPResponseHead(version: .http1_1, status: .ok) - let receiveResponseHeadAction = state.receiveResponseHead(head) - guard case .succeedResponseHead(head, let continuation) = receiveResponseHeadAction else { + let receiveResponseHeadAction = state.receiveResponseHead( + head, + delegate: NoOpAsyncSequenceProducerDelegate() + ) + guard case .succeedResponseHead(_, let continuation) = receiveResponseHeadAction else { return XCTFail("Unexpected action: \(receiveResponseHeadAction)") } @@ -150,11 +230,9 @@ final class Transaction_StateMachineTests: XCTestCase { await XCTAssertThrowsError(try await withCheckedThrowingContinuation(workaround)) } - #endif } } -#if compiler(>=5.5.2) && canImport(_Concurrency) @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction.StateMachine.StartExecutionAction: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { @@ -193,7 +271,7 @@ extension Transaction.StateMachine.NextWriteAction: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { case (.writeAndWait(let lhsEx), .writeAndWait(let rhsEx)), - (.writeAndContinue(let lhsEx), .writeAndContinue(let rhsEx)): + (.writeAndContinue(let lhsEx), .writeAndContinue(let rhsEx)): if let lhsMock = lhsEx as? MockRequestExecutor, let rhsMock = rhsEx as? MockRequestExecutor { return lhsMock === rhsMock } @@ -205,4 +283,3 @@ extension Transaction.StateMachine.NextWriteAction: Equatable { } } } -#endif diff --git a/Tests/AsyncHTTPClientTests/TransactionTests+XCTest.swift b/Tests/AsyncHTTPClientTests/TransactionTests+XCTest.swift deleted file mode 100644 index 190260647..000000000 --- a/Tests/AsyncHTTPClientTests/TransactionTests+XCTest.swift +++ /dev/null @@ -1,39 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// TransactionTests+XCTest.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -extension TransactionTests { - static var allTests: [(String, (TransactionTests) -> () throws -> Void)] { - return [ - ("testCancelAsyncRequest", testCancelAsyncRequest), - ("testResponseStreamingWorks", testResponseStreamingWorks), - ("testIgnoringResponseBodyWorks", testIgnoringResponseBodyWorks), - ("testWriteBackpressureWorks", testWriteBackpressureWorks), - ("testSimpleGetRequest", testSimpleGetRequest), - ("testSimplePostRequest", testSimplePostRequest), - ("testPostStreamFails", testPostStreamFails), - ("testResponseStreamFails", testResponseStreamFails), - ("testBiDirectionalStreamingHTTP2", testBiDirectionalStreamingHTTP2), - ] - } -} diff --git a/Tests/AsyncHTTPClientTests/TransactionTests.swift b/Tests/AsyncHTTPClientTests/TransactionTests.swift index 7e2c62a0d..8e6464a5b 100644 --- a/Tests/AsyncHTTPClientTests/TransactionTests.swift +++ b/Tests/AsyncHTTPClientTests/TransactionTests.swift @@ -12,27 +12,27 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOFoundationCompat import NIOHTTP1 import NIOPosix import XCTest -#if compiler(>=5.5.2) && canImport(_Concurrency) +@testable import AsyncHTTPClient + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) typealias PreparedRequest = HTTPClientRequest.Prepared -#endif +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class TransactionTests: XCTestCase { func testCancelAsyncRequest() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + let scheduledRequestCanceled = loop.makePromise(of: Void.self) + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -41,34 +41,90 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return XCTFail("Expected to have a request here.") } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let queuer = MockTaskQueuer() + let queuer = MockTaskQueuer { _ in + scheduledRequestCanceled.succeed() + } transaction.requestWasQueued(queuer) + XCTAssertEqual(queuer.hitCancelCount, 0) Task.detached { try await Task.sleep(nanoseconds: 5 * 1000 * 1000) transaction.cancel() } - XCTAssertEqual(queuer.hitCancelCount, 0) - await XCTAssertThrowsError(try await responseTask.value) { - XCTAssertEqual($0 as? HTTPClientError, .cancelled) + await XCTAssertThrowsError(try await responseTask.value) { error in + XCTAssertTrue(error is CancellationError, "unexpected error \(error)") + } + + // self.fulfillment(of:) is not available on Linux + try await scheduledRequestCanceled.futureResult.timeout(after: .seconds(1)).get() + } + } + + func testDeadlineExceededWhileQueuedAndExecutorImmediatelyCancelsTask() { + XCTAsyncTest { + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } + + var request = HTTPClientRequest(url: "https://localhost/") + request.method = .GET + var maybePreparedRequest: PreparedRequest? + XCTAssertNoThrow(maybePreparedRequest = try PreparedRequest(request)) + guard let preparedRequest = maybePreparedRequest else { + return XCTFail("Expected to have a request here.") + } + let (transaction, responseTask) = await Transaction.makeWithResultTask( + request: preparedRequest, + preferredEventLoop: loop + ) + + let queuer = MockTaskQueuer() + transaction.requestWasQueued(queuer) + + transaction.deadlineExceeded() + + struct Executor: HTTPRequestExecutor { + func writeRequestBodyPart( + _: NIOCore.IOData, + request: AsyncHTTPClient.HTTPExecutableRequest, + promise: NIOCore.EventLoopPromise? + ) { + XCTFail() + } + + func finishRequestBodyStream( + _ task: AsyncHTTPClient.HTTPExecutableRequest, + promise: NIOCore.EventLoopPromise? + ) { + XCTFail() + } + + func demandResponseBodyStream(_: AsyncHTTPClient.HTTPExecutableRequest) { + XCTFail() + } + + func cancelRequest(_ task: AsyncHTTPClient.HTTPExecutableRequest) { + task.fail(HTTPClientError.cancelled) + } + } + + transaction.willExecuteRequest(Executor()) + + await XCTAssertThrowsError(try await responseTask.value) { error in + XCTAssertEqualTypeAndValue(error, HTTPClientError.deadlineExceeded) } - XCTAssertEqual(queuer.hitCancelCount, 1) } - #endif } func testResponseStreamingWorks() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -78,14 +134,14 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) transaction.willExecuteRequest(executor) @@ -100,11 +156,11 @@ final class TransactionTests: XCTestCase { XCTAssertEqual(response.headers, responseHead.headers) XCTAssertEqual(response.version, responseHead.version) - let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }.makeAsyncIterator()) + let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }) - for i in 0..<100 { - XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") + XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") + for i in 0..<100 { async let part = iterator.next() XCTAssertNoThrow(try executor.receiveResponseDemand()) @@ -115,7 +171,6 @@ final class TransactionTests: XCTestCase { XCTAssertEqual(result, ByteBuffer(integer: i)) } - XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") async let part = iterator.next() XCTAssertNoThrow(try executor.receiveResponseDemand()) executor.resetResponseStreamDemandSignal() @@ -123,15 +178,12 @@ final class TransactionTests: XCTestCase { let result = try await part XCTAssertNil(result) } - #endif } func testIgnoringResponseBodyWorks() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -141,9 +193,9 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return } - var tuple: (Transaction, Task)! = Transaction.makeWithResultTask( + var tuple: (Transaction, Task)! = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let transaction = tuple.0 @@ -152,9 +204,10 @@ final class TransactionTests: XCTestCase { let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) executor.runRequest(transaction) + await loop.run() let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["foo": "bar"]) XCTAssertFalse(executor.signalledDemandForResponseBody) @@ -174,15 +227,12 @@ final class TransactionTests: XCTestCase { transaction.receiveResponseBodyParts([ByteBuffer(string: "foo bar")]) transaction.succeedRequest(nil) } - #endif } func testWriteBackpressureWorks() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } let streamWriter = AsyncSequenceWriter() XCTAssertFalse(streamWriter.hasDemand, "Did not expect to have a demand at this point") @@ -196,28 +246,31 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return XCTFail("Expected to have a request here.") } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() for i in 0..<100 { XCTAssertFalse(streamWriter.hasDemand, "Did not expect to have demand yet") transaction.resumeRequestBodyStream() - await streamWriter.demand() // wait's for the stream writer to signal demand + await streamWriter.demand() // wait's for the stream writer to signal demand transaction.pauseRequestBodyStream() let part = ByteBuffer(integer: i) streamWriter.write(part) - XCTAssertNoThrow(try executor.receiveRequestBody { - XCTAssertEqual($0, part) - }) + XCTAssertNoThrow( + try executor.receiveRequestBody { + XCTAssertEqual($0, part) + } + ) } transaction.resumeRequestBodyStream() @@ -237,7 +290,7 @@ final class TransactionTests: XCTestCase { XCTAssertEqual(response.headers, responseHead.headers) XCTAssertEqual(response.version, responseHead.version) - let iterator = SharedIterator(response.body.makeAsyncIterator()) + let iterator = SharedIterator(response.body) XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") async let part = iterator.next() @@ -248,12 +301,9 @@ final class TransactionTests: XCTestCase { let result = try await part XCTAssertNil(result) } - #endif } func testSimpleGetRequest() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) let eventLoop = eventLoopGroup.next() @@ -264,12 +314,14 @@ final class TransactionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - )) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") } @@ -282,7 +334,7 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return XCTFail("Expected to have a request here.") } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, preferredEventLoop: eventLoopGroup.next() ) @@ -306,15 +358,12 @@ final class TransactionTests: XCTestCase { RequestInfo(data: "", requestNumber: 1, connectionNumber: 0) ) } - #endif } func testSimplePostRequest() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .POST @@ -324,17 +373,20 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return XCTFail("Expected to have a request here.") } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() executor.resumeRequestBodyStream() - XCTAssertNoThrow(try executor.receiveRequestBody { - XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") - }) + XCTAssertNoThrow( + try executor.receiveRequestBody { + XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") + } + ) XCTAssertNoThrow(try executor.receiveEndOfStream()) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["foo": "bar"]) @@ -346,15 +398,12 @@ final class TransactionTests: XCTestCase { XCTAssertEqual(response.version, .http1_1) XCTAssertEqual(response.headers, ["foo": "bar"]) } - #endif } func testPostStreamFails() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } let writer = AsyncSequenceWriter() @@ -366,21 +415,24 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return XCTFail("Expected to have a request here.") } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() executor.resumeRequestBodyStream() await writer.demand() writer.write(.init(string: "Hello world!")) - XCTAssertNoThrow(try executor.receiveRequestBody { - XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") - }) + XCTAssertNoThrow( + try executor.receiveRequestBody { + XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") + } + ) XCTAssertFalse(executor.isCancelled) struct WriteError: Error, Equatable {} @@ -391,15 +443,12 @@ final class TransactionTests: XCTestCase { } XCTAssertNoThrow(try executor.receiveCancellation()) } - #endif } func testResponseStreamFails() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } - XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + XCTAsyncTest(timeout: 30) { + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -409,14 +458,14 @@ final class TransactionTests: XCTestCase { guard let preparedRequest = maybePreparedRequest else { return } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) transaction.willExecuteRequest(executor) @@ -427,17 +476,19 @@ final class TransactionTests: XCTestCase { transaction.receiveResponseHead(responseHead) let response = try await responseTask.value + XCTAssertEqual(response.status, responseHead.status) XCTAssertEqual(response.headers, responseHead.headers) XCTAssertEqual(response.version, responseHead.version) XCTAssertFalse(executor.signalledDemandForResponseBody, "Demand was not signalled yet.") - let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }.makeAsyncIterator()) + let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }) async let part1 = iterator.next() XCTAssertNoThrow(try executor.receiveResponseDemand()) executor.resetResponseStreamDemandSignal() transaction.receiveResponseBodyParts([ByteBuffer(integer: 123)]) + let result = try await part1 XCTAssertEqual(result, ByteBuffer(integer: 123)) @@ -454,12 +505,9 @@ final class TransactionTests: XCTestCase { XCTAssertEqual($0 as? HTTPClientError, .readTimeout) } } - #endif } func testBiDirectionalStreamingHTTP2() { - #if compiler(>=5.5.2) && canImport(_Concurrency) - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } XCTAsyncTest { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) let eventLoop = eventLoopGroup.next() @@ -470,12 +518,14 @@ final class TransactionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - )) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") } @@ -486,14 +536,14 @@ final class TransactionTests: XCTestCase { var request = HTTPClientRequest(url: "https://localhost:\(httpBin.port)/") request.method = .POST request.headers = ["host": "localhost:\(httpBin.port)"] - request.body = .stream(streamWriter, length: .known(800)) + request.body = .stream(streamWriter, length: .known(Int64(800))) var maybePreparedRequest: PreparedRequest? XCTAssertNoThrow(maybePreparedRequest = try PreparedRequest(request)) guard let preparedRequest = maybePreparedRequest else { return } - let (transaction, responseTask) = Transaction.makeWithResultTask( + let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, preferredEventLoop: eventLoopGroup.next() ) @@ -508,7 +558,7 @@ final class TransactionTests: XCTestCase { XCTAssertEqual(response.version, .http2) XCTAssertEqual(delegate.hitStreamClosed, 0) - let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }.makeAsyncIterator()) + let iterator = SharedIterator(response.body.filter { $0.readableBytes > 0 }) // at this point we can start to write to the stream and wait for the results @@ -529,42 +579,95 @@ final class TransactionTests: XCTestCase { XCTAssertNil(final) XCTAssertEqual(delegate.hitStreamClosed, 1) } - #endif } } -#if compiler(>=5.5.2) && canImport(_Concurrency) - // This needs a small explanation. If an iterator is a struct, it can't be used across multiple // tasks. Since we want to wait for things to happen in tests, we need to `async let`, which creates // implicit tasks. Therefore we need to wrap our iterator struct. @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -actor SharedIterator { - private var iterator: Iterator +final class SharedIterator: Sendable where Wrapped.Element: Sendable { + private struct State: @unchecked Sendable { + var wrappedIterator: Wrapped.AsyncIterator + var nextCallInProgress: Bool = false + } - init(_ iterator: Iterator) { - self.iterator = iterator + private let state: NIOLockedValueBox + + init(_ sequence: Wrapped) { + self.state = NIOLockedValueBox(State(wrappedIterator: sequence.makeAsyncIterator())) } - func next() async throws -> Iterator.Element? { - var iter = self.iterator - defer { self.iterator = iter } + func next() async throws -> Wrapped.Element? { + var iter = self.state.withLockedValue { + precondition($0.nextCallInProgress == false) + $0.nextCallInProgress = true + return $0.wrappedIterator + } + + defer { + self.state.withLockedValue { + precondition($0.nextCallInProgress == true) + $0.nextCallInProgress = false + $0.wrappedIterator = iter + } + } return try await iter.next() } } +/// non fail-able promise that only supports one observer +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +private actor Promise { + private enum State { + case initialised + case fulfilled(Value) + } + + private var state: State = .initialised + + private var observer: CheckedContinuation? + + init() {} + + func fulfil(_ value: Value) { + switch self.state { + case .initialised: + self.state = .fulfilled(value) + self.observer?.resume(returning: value) + case .fulfilled: + preconditionFailure("\(Self.self) over fulfilled") + } + } + + var value: Value { + get async { + switch self.state { + case .initialised: + return await withCheckedContinuation { (continuation: CheckedContinuation) in + precondition(self.observer == nil, "\(Self.self) supports only one observer") + self.observer = continuation + } + case .fulfilled(let value): + return value + } + } + } +} + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction { fileprivate static func makeWithResultTask( - request: PreparedRequest, + request: sending PreparedRequest, requestOptions: RequestOptions = .forTests(), logger: Logger = Logger(label: "test"), connectionDeadline: NIODeadline = .distantFuture, preferredEventLoop: EventLoop - ) -> (Transaction, _Concurrency.Task) { - let transactionPromise = preferredEventLoop.makePromise(of: Transaction.self) - let result = Task { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + ) async -> (Transaction, _Concurrency.Task) { + let transactionPromise = Promise() + let task = Task { + try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in let transaction = Transaction( request: request, requestOptions: requestOptions, @@ -573,13 +676,12 @@ extension Transaction { preferredEventLoop: preferredEventLoop, responseContinuation: continuation ) - transactionPromise.succeed(transaction) + Task { + await transactionPromise.fulfil(transaction) + } } } - // the promise can never fail and it is therefore safe to force unwrap - let transaction = try! transactionPromise.futureResult.wait() - return (transaction, result) + return (await transactionPromise.value, task) } } -#endif diff --git a/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift b/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift index fbc429b10..6cdcf4f8a 100644 --- a/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift +++ b/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift @@ -11,26 +11,25 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -/* - * Copyright 2021, gRPC Authors All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#if compiler(>=5.5.2) && canImport(_Concurrency) +// +// Copyright 2021, gRPC Authors All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + import XCTest extension XCTestCase { - @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) /// Cross-platform XCTest support for async-await tests. /// /// Currently the Linux implementation of XCTest doesn't have async-await support. @@ -39,6 +38,7 @@ extension XCTestCase { /// /// - NOTE: Support for Linux is tracked by https://bugs.swift.org/browse/SR-14403. /// - NOTE: Implementation currently in progress: https://github.com/apple/swift-corelibs-xctest/pull/326 + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) func XCTAsyncTest( expectationDescription: String = "Async operation", timeout: TimeInterval = 30, @@ -53,7 +53,7 @@ extension XCTestCase { try await operation() } catch { XCTFail("Error thrown while executing \(function): \(error)", file: file, line: line) - Thread.callStackSymbols.forEach { print($0) } + for symbol in Thread.callStackSymbols { print(symbol) } } expectation.fulfill() } @@ -65,7 +65,7 @@ extension XCTestCase { internal func XCTAssertThrowsError( _ expression: @autoclosure () async throws -> T, verify: (Error) -> Void = { _ in }, - file: StaticString = #file, + file: StaticString = #filePath, line: UInt = #line ) async { do { @@ -79,7 +79,7 @@ internal func XCTAssertThrowsError( @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) internal func XCTAssertNoThrowWithResult( _ expression: @autoclosure () async throws -> Result, - file: StaticString = #file, + file: StaticString = #filePath, line: UInt = #line ) async -> Result? { do { @@ -89,5 +89,3 @@ internal func XCTAssertNoThrowWithResult( } return nil } - -#endif diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift deleted file mode 100644 index cebced614..000000000 --- a/Tests/LinuxMain.swift +++ /dev/null @@ -1,64 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// LinuxMain.swift -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - -#if os(Linux) || os(FreeBSD) -@testable import AsyncHTTPClientTests - -XCTMain([ - testCase(AsyncAwaitEndToEndTests.allTests), - testCase(HTTP1ClientChannelHandlerTests.allTests), - testCase(HTTP1ConnectionStateMachineTests.allTests), - testCase(HTTP1ConnectionTests.allTests), - testCase(HTTP1ProxyConnectHandlerTests.allTests), - testCase(HTTP2ClientRequestHandlerTests.allTests), - testCase(HTTP2ClientTests.allTests), - testCase(HTTP2ConnectionTests.allTests), - testCase(HTTP2IdleHandlerTests.allTests), - testCase(HTTPClientCookieTests.allTests), - testCase(HTTPClientInternalTests.allTests), - testCase(HTTPClientNIOTSTests.allTests), - testCase(HTTPClientReproTests.allTests), - testCase(HTTPClientRequestTests.allTests), - testCase(HTTPClientSOCKSTests.allTests), - testCase(HTTPClientTests.allTests), - testCase(HTTPClientUncleanSSLConnectionShutdownTests.allTests), - testCase(HTTPConnectionPoolTests.allTests), - testCase(HTTPConnectionPool_FactoryTests.allTests), - testCase(HTTPConnectionPool_HTTP1ConnectionsTests.allTests), - testCase(HTTPConnectionPool_HTTP1StateMachineTests.allTests), - testCase(HTTPConnectionPool_HTTP2ConnectionsTests.allTests), - testCase(HTTPConnectionPool_HTTP2StateMachineTests.allTests), - testCase(HTTPConnectionPool_ManagerTests.allTests), - testCase(HTTPConnectionPool_RequestQueueTests.allTests), - testCase(HTTPRequestStateMachineTests.allTests), - testCase(LRUCacheTests.allTests), - testCase(RequestBagTests.allTests), - testCase(RequestValidationTests.allTests), - testCase(SOCKSEventsHandlerTests.allTests), - testCase(SSLContextCacheTests.allTests), - testCase(TLSEventsHandlerTests.allTests), - testCase(TransactionTests.allTests), - testCase(Transaction_StateMachineTests.allTests), -]) -#endif diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index 1cd4f2140..000000000 --- a/docker/Dockerfile +++ /dev/null @@ -1,34 +0,0 @@ -ARG swift_version=5.4 -ARG ubuntu_version=bionic -ARG base_image=swift:$swift_version-$ubuntu_version -FROM $base_image -# needed to do again after FROM due to docker limitation -ARG swift_version -ARG ubuntu_version - -# set as UTF-8 -RUN apt-get update && apt-get install -y locales locales-all -ENV LC_ALL en_US.UTF-8 -ENV LANG en_US.UTF-8 -ENV LANGUAGE en_US.UTF-8 - -# dependencies -RUN apt-get update && apt-get install -y wget -RUN apt-get update && apt-get install -y lsof dnsutils netcat-openbsd net-tools libz-dev curl jq # used by integration tests - -# ruby and jazzy for docs generation -RUN apt-get update && apt-get install -y ruby ruby-dev libsqlite3-dev build-essential -# jazzy no longer works on xenial as ruby is too old. -RUN if [ "${ubuntu_version}" = "focal" ] ; then echo "gem: --no-document" > ~/.gemrc; fi -RUN if [ "${ubuntu_version}" = "focal" ] ; then gem install jazzy; fi - -# tools -RUN mkdir -p $HOME/.tools -RUN echo 'export PATH="$HOME/.tools:$PATH"' >> $HOME/.profile - -# swiftformat (until part of the toolchain) - -ARG swiftformat_version=0.48.8 -RUN git clone --branch $swiftformat_version --depth 1 https://github.com/nicklockwood/SwiftFormat $HOME/.tools/swift-format -RUN cd $HOME/.tools/swift-format && swift build -c release -RUN ln -s $HOME/.tools/swift-format/.build/release/swiftformat $HOME/.tools/swiftformat diff --git a/docker/docker-compose.1804.54.yaml b/docker/docker-compose.1804.54.yaml deleted file mode 100644 index 660429851..000000000 --- a/docker/docker-compose.1804.54.yaml +++ /dev/null @@ -1,18 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:18.04-5.4 - build: - args: - ubuntu_version: "bionic" - swift_version: "5.4" - - test: - image: async-http-client:18.04-5.4 - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:18.04-5.4 diff --git a/docker/docker-compose.2004.55.yaml b/docker/docker-compose.2004.55.yaml deleted file mode 100644 index 4d0a12ee7..000000000 --- a/docker/docker-compose.2004.55.yaml +++ /dev/null @@ -1,18 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:20.04-5.5 - build: - args: - ubuntu_version: "focal" - swift_version: "5.5" - - test: - image: async-http-client:20.04-5.5 - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:20.04-5.5 diff --git a/docker/docker-compose.2004.56.yaml b/docker/docker-compose.2004.56.yaml deleted file mode 100644 index ed61267a9..000000000 --- a/docker/docker-compose.2004.56.yaml +++ /dev/null @@ -1,18 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:20.04-5.6 - build: - args: - ubuntu_version: "focal" - swift_version: "5.6" - - test: - image: async-http-client:20.04-5.6 - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:20.04-5.6 diff --git a/docker/docker-compose.2004.57.yaml b/docker/docker-compose.2004.57.yaml deleted file mode 100644 index 16c564482..000000000 --- a/docker/docker-compose.2004.57.yaml +++ /dev/null @@ -1,17 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:20.04-5.7 - build: - args: - base_image: "swiftlang/swift:nightly-main-focal" - - test: - image: async-http-client:20.04-5.7 - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:20.04-5.7 diff --git a/docker/docker-compose.2004.main.yaml b/docker/docker-compose.2004.main.yaml deleted file mode 100644 index 11c7517ba..000000000 --- a/docker/docker-compose.2004.main.yaml +++ /dev/null @@ -1,17 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:20.04-main - build: - args: - base_image: "swiftlang/swift:nightly-main-focal" - - test: - image: async-http-client:20.04-main - environment: [] - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:20.04-main diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml deleted file mode 100644 index 6269e953b..000000000 --- a/docker/docker-compose.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# this file is not designed to be run directly -# instead, use the docker-compose.. files -# eg docker-compose -f docker/docker-compose.yaml -f docker/docker-compose.1804.50.yaml run test -version: "3" - -services: - - runtime-setup: - image: async-http-client:default - build: - context: . - dockerfile: Dockerfile - - common: &common - image: async-http-client:default - depends_on: [runtime-setup] - volumes: - - ~/.ssh:/root/.ssh - - ..:/code:z - working_dir: /code - cap_drop: - - CAP_NET_RAW - - CAP_NET_BIND_SERVICE - - soundness: - <<: *common - command: /bin/bash -xcl "./scripts/soundness.sh" - - test: - <<: *common - command: /bin/bash -xcl "swift test --parallel -Xswiftc -warnings-as-errors $${SANITIZER_ARG-}" - - # util - - shell: - <<: *common - entrypoint: /bin/bash - - docs: - <<: *common - command: /bin/bash -cl "./scripts/generate_docs.sh" diff --git a/scripts/check_no_api_breakages.sh b/scripts/check_no_api_breakages.sh deleted file mode 100755 index 2d7028617..000000000 --- a/scripts/check_no_api_breakages.sh +++ /dev/null @@ -1,68 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2022 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -##===----------------------------------------------------------------------===## -## -## This source file is part of the SwiftNIO open source project -## -## Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of SwiftNIO project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu - -function usage() { - echo >&2 "Usage: $0 REPO-GITHUB-URL NEW-VERSION OLD-VERSIONS..." - echo >&2 - echo >&2 "This script requires a Swift 5.6+ toolchain." - echo >&2 - echo >&2 "Examples:" - echo >&2 - echo >&2 "Check between main and tag 1.9.0 of async-http-client:" - echo >&2 " $0 https://github.com/swift-server/async-http-client main 1.9.0" - echo >&2 - echo >&2 "Check between HEAD and commit 64cf63d7 using the provided toolchain:" - echo >&2 " xcrun --toolchain org.swift.5120190702a $0 ../some-local-repo HEAD 64cf63d7" -} - -if [[ $# -lt 3 ]]; then - usage - exit 1 -fi - -tmpdir=$(mktemp -d /tmp/.check-api_XXXXXX) -repo_url=$1 -new_tag=$2 -shift 2 - -repodir="$tmpdir/repo" -git clone "$repo_url" "$repodir" -git -C "$repodir" fetch -q origin '+refs/pull/*:refs/remotes/origin/pr/*' -cd "$repodir" -git checkout -q "$new_tag" - -for old_tag in "$@"; do - echo "Checking public API breakages from $old_tag to $new_tag" - - swift package diagnose-api-breaking-changes "$old_tag" -done - -echo done diff --git a/scripts/generate_contributors_list.sh b/scripts/generate_contributors_list.sh deleted file mode 100755 index 00c162638..000000000 --- a/scripts/generate_contributors_list.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu -here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -contributors=$( cd "$here"/.. && git shortlog -es | cut -f2 | sed 's/^/- /' ) - -cat > "$here/../CONTRIBUTORS.txt" <<- EOF - For the purpose of tracking copyright, this is the list of individuals and - organizations who have contributed source code to the AsyncHTTPClient. - - For employees of an organization/company where the copyright of work done - by employees of that company is held by the company itself, only the company - needs to be listed here. - - ## COPYRIGHT HOLDERS - - - Apple Inc. (all contributors with '@apple.com') - - ### Contributors - - $contributors - - **Updating this list** - - Please do not edit this file manually. It is generated using \`./scripts/generate_contributors_list.sh\`. If a name is misspelled or appearing multiple times: add an entry in \`./.mailmap\` -EOF diff --git a/scripts/generate_docs.sh b/scripts/generate_docs.sh deleted file mode 100755 index 82da814d3..000000000 --- a/scripts/generate_docs.sh +++ /dev/null @@ -1,114 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -e - -my_path="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -root_path="$my_path/.." -version=$(git describe --abbrev=0 --tags || echo "main") -modules=(AsyncHTTPClient) - -if [[ "$(uname -s)" == "Linux" ]]; then - # build code if required - if [[ ! -d "$root_path/.build/x86_64-unknown-linux" ]]; then - swift build - fi - # setup source-kitten if required - mkdir -p "$root_path/.build/sourcekitten" - source_kitten_source_path="$root_path/.build/sourcekitten/source" - if [[ ! -d "$source_kitten_source_path" ]]; then - git clone https://github.com/jpsim/SourceKitten.git "$source_kitten_source_path" - fi - source_kitten_path="$source_kitten_source_path/.build/debug" - if [[ ! -d "$source_kitten_path" ]]; then - rm -rf "$source_kitten_source_path/.swift-version" - cd "$source_kitten_source_path" && swift build && cd "$root_path" - fi - # generate - for module in "${modules[@]}"; do - if [[ ! -f "$root_path/.build/sourcekitten/$module.json" ]]; then - "$source_kitten_path/sourcekitten" doc --spm --module-name $module > "$root_path/.build/sourcekitten/$module.json" - fi - done -fi - -[[ -d docs/$version ]] || mkdir -p docs/$version -[[ -d async-http-client.xcodeproj ]] || swift package generate-xcodeproj - -# run jazzy -if ! command -v jazzy > /dev/null; then - gem install jazzy --no-ri --no-rdoc -fi - -jazzy_dir="$root_path/.build/jazzy" -rm -rf "$jazzy_dir" -mkdir -p "$jazzy_dir" - -module_switcher="$jazzy_dir/README.md" -jazzy_args=(--clean - --author 'AsyncHTTPClient team' - --readme "$module_switcher" - --author_url https://github.com/swift-server/async-http-client - --github_url https://github.com/swift-server/async-http-client - --github-file-prefix "https://github.com/swift-server/async-http-client/tree/$version" - --theme fullwidth - --xcodebuild-arguments -scheme,async-http-client-Package) -cat > "$module_switcher" <<"EOF" -# AsyncHTTPClient Docs - -AsyncHTTPClient is a Swift HTTP Client package. - -To get started with AsyncHTTPClient, [`import AsyncHTTPClient`](../AsyncHTTPClient/index.html). The -most important type is [`HTTPClient`](https://swift-server.github.io/async-http-client/docs/current/AsyncHTTPClient/Classes/HTTPClient.html) -which you can use to emit log messages. - -EOF - -tmp=`mktemp -d` -for module in "${modules[@]}"; do - args=("${jazzy_args[@]}" --output "$jazzy_dir/docs/$version/$module" --docset-path "$jazzy_dir/docset/$version/$module" - --module "$module" --module-version $version - --root-url "https://swift-server.github.io/async-http-client/docs/$version/$module/") - if [[ -f "$root_path/.build/sourcekitten/$module.json" ]]; then - args+=(--sourcekitten-sourcefile "$root_path/.build/sourcekitten/$module.json") - fi - jazzy "${args[@]}" -done - -# push to github pages -if [[ $PUSH == true ]]; then - BRANCH_NAME=$(git rev-parse --abbrev-ref HEAD) - GIT_AUTHOR=$(git --no-pager show -s --format='%an <%ae>' HEAD) - git fetch origin +gh-pages:gh-pages - git checkout gh-pages - rm -rf "docs/$version" - rm -rf "docs/current" - cp -r "$jazzy_dir/docs/$version" docs/ - cp -r "docs/$version" docs/current - git add --all docs - echo '' > index.html - git add index.html - touch .nojekyll - git add .nojekyll - changes=$(git diff-index --name-only HEAD) - if [[ -n "$changes" ]]; then - echo -e "changes detected\n$changes" - git commit --author="$GIT_AUTHOR" -m "publish $version docs" - git push origin gh-pages - else - echo "no changes detected" - fi - git checkout -f $BRANCH_NAME -fi diff --git a/scripts/generate_linux_tests.rb b/scripts/generate_linux_tests.rb deleted file mode 100755 index ed887f83c..000000000 --- a/scripts/generate_linux_tests.rb +++ /dev/null @@ -1,231 +0,0 @@ -#!/usr/bin/env ruby - -# -# process_test_files.rb -# -# Copyright 2016 Tony Stone -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Created by Tony Stone on 5/4/16. -# -require 'getoptlong' -require 'fileutils' -require 'pathname' - -include FileUtils - -# -# This ruby script will auto generate LinuxMain.swift and the +XCTest.swift extension files for Swift Package Manager on Linux platforms. -# -# See https://github.com/apple/swift-corelibs-xctest/blob/master/Documentation/Linux.md -# -def header(fileName) - string = <<-eos -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -// -// -// -import XCTest - -/// -/// NOTE: This file was generated by generate_linux_tests.rb -/// -/// Do NOT edit this file directly as it will be regenerated automatically when needed. -/// - eos - - string - .sub('', File.basename(fileName)) - .sub('', Time.now.to_s) -end - -def createExtensionFile(fileName, classes) - extensionFile = fileName.sub! '.swift', '+XCTest.swift' - print 'Creating file: ' + extensionFile + "\n" - - File.open(extensionFile, 'w') do |file| - file.write header(extensionFile) - file.write "\n" - - for classArray in classes - file.write 'extension ' + classArray[0] + " {\n" - file.write ' static var allTests: [(String, (' + classArray[0] + ") -> () throws -> Void)] {\n" - file.write " return [\n" - - for funcName in classArray[1] - file.write ' ("' + funcName + '", ' + funcName + "),\n" - end - - file.write " ]\n" - file.write " }\n" - file.write "}\n" - end - end -end - -def createLinuxMain(testsDirectory, allTestSubDirectories, files) - fileName = testsDirectory + '/LinuxMain.swift' - print 'Creating file: ' + fileName + "\n" - - File.open(fileName, 'w') do |file| - file.write header(fileName) - file.write "\n" - - file.write "#if os(Linux) || os(FreeBSD)\n" - for testSubDirectory in allTestSubDirectories.sort { |x, y| x <=> y } - file.write '@testable import ' + testSubDirectory + "\n" - end - file.write "\n" - file.write "XCTMain([\n" - - testCases = [] - for classes in files - for classArray in classes - testCases << classArray[0] - end - end - - for testCase in testCases.sort { |x, y| x <=> y } - file.write ' testCase(' + testCase + ".allTests),\n" - end - file.write "])\n" - file.write "#endif\n" - end -end - -def parseSourceFile(fileName) - puts 'Parsing file: ' + fileName + "\n" - - classes = [] - currentClass = nil - inIfLinux = false - inElse = false - ignore = false - - # - # Read the file line by line - # and parse to find the class - # names and func names - # - File.readlines(fileName).each do |line| - if inIfLinux - if /\#else/.match(line) - inElse = true - ignore = true - else - if /\#end/.match(line) - inElse = false - inIfLinux = false - ignore = false - end - end - else - if /\#if[ \t]+os\(Linux\)/.match(line) - inIfLinux = true - ignore = false - end - end - - next if ignore - # Match class or func - match = line[/class[ \t]+[a-zA-Z0-9_]*(?=[ \t]*:[ \t]*XCTestCase)|func[ \t]+test[a-zA-Z0-9_]*(?=[ \t]*\(\))/, 0] - if match - - if match[/class/, 0] == 'class' - className = match.sub(/^class[ \t]+/, '') - # - # Create a new class / func structure - # and add it to the classes array. - # - currentClass = [className, []] - classes << currentClass - else # Must be a func - funcName = match.sub(/^func[ \t]+/, '') - # - # Add each func name the the class / func - # structure created above. - # - currentClass[1] << funcName - end - end - end - classes -end - -# -# Main routine -# -# - -testsDirectory = 'Tests' - -options = GetoptLong.new(['--tests-dir', GetoptLong::OPTIONAL_ARGUMENT]) -options.quiet = true - -begin - options.each do |option, value| - case option - when '--tests-dir' - testsDirectory = value - end - end -rescue GetoptLong::InvalidOption -end - -allTestSubDirectories = [] -allFiles = [] - -Dir[testsDirectory + '/*'].each do |subDirectory| - next unless File.directory?(subDirectory) - directoryHasClasses = false - Dir[subDirectory + '/*Test{s,}.swift'].each do |fileName| - next unless File.file? fileName - fileClasses = parseSourceFile(fileName) - - # - # If there are classes in the - # test source file, create an extension - # file for it. - # - next unless fileClasses.count > 0 - createExtensionFile(fileName, fileClasses) - directoryHasClasses = true - allFiles << fileClasses - end - - if directoryHasClasses - allTestSubDirectories << Pathname.new(subDirectory).split.last.to_s - end -end - -# -# Last step is the create a LinuxMain.swift file that -# references all the classes and funcs in the source files. -# -if allFiles.count > 0 - createLinuxMain(testsDirectory, allTestSubDirectories, allFiles) -end -# eof diff --git a/scripts/soundness.sh b/scripts/soundness.sh deleted file mode 100755 index da9a91d24..000000000 --- a/scripts/soundness.sh +++ /dev/null @@ -1,164 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2022 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu -here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -function replace_acceptable_years() { - # this needs to replace all acceptable forms with 'YEARS' - sed -e 's/20[12][0-9]-20[12][0-9]/YEARS/' -e 's/2019/YEARS/' -e 's/2020/YEARS/' -e 's/2021/YEARS/' -e 's/2022/YEARS/' -} - -printf "=> Checking linux tests... " -FIRST_OUT="$(git status --porcelain)" -ruby "$here/../scripts/generate_linux_tests.rb" > /dev/null -SECOND_OUT="$(git status --porcelain)" -if [[ "$FIRST_OUT" != "$SECOND_OUT" ]]; then - printf "\033[0;31mmissing changes!\033[0m\n" - git --no-pager diff - exit 1 -else - printf "\033[0;32mokay.\033[0m\n" -fi - -printf "=> Checking for unacceptable language... " -# This greps for unacceptable terminology. The square bracket[s] are so that -# "git grep" doesn't find the lines that greps :). -unacceptable_terms=( - -e blacklis[t] - -e whitelis[t] - -e slav[e] - -e sanit[y] -) -if git grep --color=never -i "${unacceptable_terms[@]}" > /dev/null; then - printf "\033[0;31mUnacceptable language found.\033[0m\n" - git grep -i "${unacceptable_terms[@]}" - exit 1 -fi -printf "\033[0;32mokay.\033[0m\n" - -printf "=> Checking format... " -FIRST_OUT="$(git status --porcelain)" -swiftformat . > /dev/null 2>&1 -SECOND_OUT="$(git status --porcelain)" -if [[ "$FIRST_OUT" != "$SECOND_OUT" ]]; then - printf "\033[0;31mformatting issues!\033[0m\n" - git --no-pager diff - exit 1 -else - printf "\033[0;32mokay.\033[0m\n" -fi - -printf "=> Checking license headers\n" -tmp=$(mktemp /tmp/.async-http-client-soundness_XXXXXX) - -for language in swift-or-c bash dtrace; do - printf " * $language... " - declare -a matching_files - declare -a exceptions - expections=( ) - matching_files=( -name '*' ) - case "$language" in - swift-or-c) - exceptions=( -name c_nio_http_parser.c -o -name c_nio_http_parser.h -o -name cpp_magic.h -o -name Package.swift -o -name CNIOSHA1.h -o -name c_nio_sha1.c -o -name ifaddrs-android.c -o -name ifaddrs-android.h) - matching_files=( -name '*.swift' -o -name '*.c' -o -name '*.h' ) - cat > "$tmp" <<"EOF" -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) YEARS Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -EOF - ;; - bash) - matching_files=( -name '*.sh' ) - cat > "$tmp" <<"EOF" -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) YEARS Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## -EOF - ;; - dtrace) - matching_files=( -name '*.d' ) - cat > "$tmp" <<"EOF" -#!/usr/sbin/dtrace -q -s -/*===----------------------------------------------------------------------===* - * - * This source file is part of the AsyncHTTPClient open source project - * - * Copyright (c) YEARS Apple Inc. and the AsyncHTTPClient project authors - * Licensed under Apache License v2.0 - * - * See LICENSE.txt for license information - * See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors - * - * SPDX-License-Identifier: Apache-2.0 - * - *===----------------------------------------------------------------------===*/ -EOF - ;; - *) - echo >&2 "ERROR: unknown language '$language'" - ;; - esac - - expected_lines=$(cat "$tmp" | wc -l) - expected_sha=$(cat "$tmp" | shasum) - - ( - cd "$here/.." - find . \ - \( \! -path './.build/*' -a \ - \( "${matching_files[@]}" \) -a \ - \( \! \( "${exceptions[@]}" \) \) \) | while read line; do - if [[ "$(cat "$line" | replace_acceptable_years | head -n $expected_lines | shasum)" != "$expected_sha" ]]; then - printf "\033[0;31mmissing headers in file '$line'!\033[0m\n" - diff -u <(cat "$line" | replace_acceptable_years | head -n $expected_lines) "$tmp" - exit 1 - fi - done - printf "\033[0;32mokay.\033[0m\n" - ) -done - -rm "$tmp" - -# This checks for the umbrella NIO module. -printf "=> Checking for imports of umbrella NIO module... " -if git grep --color=never -i "^[ \t]*import \+NIO[ \t]*$" > /dev/null; then - printf "\033[0;31mUmbrella imports found.\033[0m\n" - git grep -i "^[ \t]*import \+NIO[ \t]*$" - exit 1 -fi -printf "\033[0;32mokay.\033[0m\n"