diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index cace99d4..106db189 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -13,5 +13,8 @@ cxx_link(echo_server base fibers2 http_server_lib TRDP::gperf)
 add_executable(s3_demo s3_demo.cc)
 cxx_link(s3_demo base awsv2_lib)
 
+add_executable(s3_old_demo s3_old_demo.cc)
+cxx_link(s3_old_demo aws_lib)
+
 add_executable(https_client_cli https_client_cli.cc)
 cxx_link(https_client_cli base fibers2 http_client_lib tls_lib)
diff --git a/examples/s3_old_demo.cc b/examples/s3_old_demo.cc
new file mode 100644
index 00000000..b9ccaaa5
--- /dev/null
+++ b/examples/s3_old_demo.cc
@@ -0,0 +1,172 @@
// Copyright 2022, Roman Gershman.  All rights reserved.
// See LICENSE for licensing terms.
//

#include <absl/flags/flag.h>
#include <absl/strings/str_split.h>
#include <absl/strings/strip.h>

#include "base/init.h"
#include "util/cloud/aws.h"
#include "util/cloud/s3.h"
#include "util/cloud/s3_file.h"

#include "util/fibers/pool.h"
#include "util/http/http_client.h"

using namespace std;
using namespace util;
using cloud::AWS;

ABSL_FLAG(string, cmd, "ls", "");
ABSL_FLAG(string, region, "us-east-1", "");
ABSL_FLAG(string, path, "", "s3://bucket/path");
ABSL_FLAG(string, endpoint, "", "s3 endpoint");
ABSL_FLAG(uint32_t, num_iters, 1, "Number of iterations");
ABSL_FLAG(uint32_t, delay, 5, "Delay in seconds between each iteration");
ABSL_FLAG(uint32_t, write_factor, 10000, "Number of 1K blocks to write");

namespace h2 = boost::beast::http;
using absl::GetFlag;

#define CHECK_EC(x)                                                                 \
  do {                                                                              \
    auto __ec$ = (x);                                                               \
    CHECK(!__ec$) << "Error: " << __ec$ << " " << __ec$.message() << " for " << #x; \
  } while (false)

template <typename Body>
std::ostream& operator<<(std::ostream& os, const h2::request<Body>& msg) {
  os << msg.method_string() << " " << msg.target() << endl;
  for (const auto& f : msg) {
    os << f.name_string() << " : " << f.value() << endl;
  }
  os << "-------------------------";

  return os;
}

void ListBuckets(AWS* aws, ProactorBase* proactor) {
  string endpoint = GetFlag(FLAGS_endpoint);
  if (endpoint.empty()) {
    endpoint = "s3.amazonaws.com:80";
  }

  vector<string_view> parts = absl::StrSplit(endpoint, ':');
  CHECK_EQ(parts.size(), 2u);

  http::Client http_client{proactor};

  http_client.set_connect_timeout_ms(2000);
  auto list_res = proactor->Await([&] {
    CHECK_EC(http_client.Connect(parts[0], parts[1]));
    return ListS3Buckets(aws, &http_client);
  });

  if (!list_res) {
    cout << "Error: " << list_res.error() << endl;
    return;
  }

  for (const auto& b : *list_res) {
    cout << b << endl;
  }
}

int main(int argc, char* argv[]) {
  MainInitGuard guard(&argc, &argv);

  unique_ptr<ProactorPool> pp;
  pp.reset(fb2::Pool::IOUring(256));
  pp->Run();

  AWS aws{"s3", GetFlag(FLAGS_region)};

  pp->GetNextProactor()->Await([&] { CHECK_EC(aws.Init()); });

  string cmd = GetFlag(FLAGS_cmd);
  string path = GetFlag(FLAGS_path);
  string endpoint = GetFlag(FLAGS_endpoint);

  if (path.empty()) {
    CHECK(cmd == "ls");
    ListBuckets(&aws, pp->GetNextProactor());
  } else {
    string_view clean = absl::StripPrefix(path, "s3://");
    string_view obj_path;
    size_t pos = clean.find('/');
    string_view bucket_name = clean.substr(0, pos);
    if (pos != string_view::npos) {
      obj_path = clean.substr(pos + 1);
    }
    cloud::S3Bucket bucket = cloud::S3Bucket::FromEndpoint(aws, endpoint, bucket_name);

    if (cmd == "ls") {
      cloud::S3Bucket::ListObjectCb cb = [](size_t sz, string_view name) { CONSOLE_INFO << name; };

      error_code ec = pp->GetNextProactor()->Await([&] {
        auto ec = bucket.Connect(300);
        if (ec)
          return ec;
        unsigned num_iters = GetFlag(FLAGS_num_iters);
        for (unsigned i = 0; i < num_iters; ++i) {
          ec = bucket.ListAllObjects(obj_path, cb);
          if (ec)
            return ec;

          if (i + 1 < num_iters)
            ThisFiber::SleepFor(chrono::seconds(GetFlag(FLAGS_delay)));
        }
        return ec;
      });

      CHECK(!ec) << ec;
    } else if (cmd == "read") {
      pp->GetNextProactor()->Await([&] {
        auto ec = bucket.Connect(300);
        CHECK(!ec);

        io::Result<io::ReadonlyFile*> res = bucket.OpenReadFile(obj_path);
        if (res) {
          io::ReadonlyFile* file = *res;
          std::unique_ptr<uint8_t[]> buf(new uint8_t[1024]);
          io::SizeOrError sz_res = file->Read(0, io::MutableBytes(buf.get(), 1024));
          if (sz_res) {
            CONSOLE_INFO << "File contents(first 1024) of " << obj_path << ":";
            CONSOLE_INFO << string_view(reinterpret_cast<const char*>(buf.get()), *sz_res);
          } else {
            LOG(ERROR) << "Error: " << sz_res.error();
          }
        } else {
          LOG(ERROR) << "Read Error: " << res.error().message();
        }
      });
    } else if (cmd == "write") {
      pp->GetNextProactor()->Await([&] {
        auto ec = bucket.Connect(300);
        CHECK(!ec);

        io::Result<io::WriteFile*> res = bucket.OpenWriteFile(obj_path);
        if (res) {
          unique_ptr<io::WriteFile> file{*res};
          CHECK(file);
          std::unique_ptr<uint8_t[]> buf(new uint8_t[1024]);
          memset(buf.get(), 'R', 1024);
          for (size_t i = 0; i < GetFlag(FLAGS_write_factor); ++i) {
            ec = file->Write(io::Bytes(buf.get(), 1024));
            CHECK(!ec);
          }
          ec = file->Close();
          CHECK(!ec);
        } else {
          LOG(ERROR) << "Error: " << res.error();
        }
      });
    } else {
      LOG(ERROR) << "Unknown command " << cmd;
    }
  }

  pp->Stop();

  return 0;
}
diff --git a/util/CMakeLists.txt b/util/CMakeLists.txt
index 39babdc1..631d570f 100644
--- a/util/CMakeLists.txt
+++ b/util/CMakeLists.txt
@@ -7,6 +7,7 @@ add_subdirectory(html)
 add_subdirectory(metrics)
 add_subdirectory(tls)
 add_subdirectory(http)
+add_subdirectory(cloud)
 
 if (WITH_AWS)
   add_subdirectory(aws)
diff --git a/util/cloud/CMakeLists.txt b/util/cloud/CMakeLists.txt
new file mode 100644
index 00000000..463745e5
--- /dev/null
+++ b/util/cloud/CMakeLists.txt
@@ -0,0 +1,6 @@
find_package(LibXml2)

add_library(aws_lib aws.cc s3.cc s3_file.cc)

cxx_link(aws_lib base OpenSSL::Crypto TRDP::rapidjson http_utils
         TRDP::pugixml http_client_lib)
\ No newline at end of file
diff --git a/util/cloud/aws.cc b/util/cloud/aws.cc
new file mode 100644
index 00000000..1e2f5e1a
--- /dev/null
+++ b/util/cloud/aws.cc
@@ -0,0 +1,642 @@
// Copyright 2022, Roman Gershman.  All rights reserved.
// See LICENSE for licensing terms.
//

#include "util/cloud/aws.h"

#include <absl/strings/str_cat.h>
#include <absl/strings/str_format.h>
#include <absl/strings/str_join.h>
#include <absl/strings/str_split.h>
#include <absl/time/clock.h>
#include <absl/time/time.h>
#include <openssl/evp.h>
#include <openssl/hmac.h>

#include <boost/beast/http/empty_body.hpp>
#include <boost/beast/http/string_body.hpp>
#include <optional>
#include <rapidjson/document.h>

#include "base/logging.h"
#include "io/file.h"
#include "io/line_reader.h"

#include "util/fibers/proactor_base.h"

#define RETURN_ON_ERR(x)                                    \
  do {                                                      \
    auto __ec = (x);                                        \
    if (__ec) {                                             \
      DLOG(ERROR) << "Error " << __ec << " while calling " #x; \
      return __ec;                                          \
    }                                                       \
  } while (0)

namespace util {
namespace cloud {

using namespace std;
namespace h2 = boost::beast::http;

namespace {

/*void EVPDigest(const ::boost::beast::multi_buffer& mb, unsigned char* md) {
  EVP_MD_CTX* ctx = EVP_MD_CTX_new();
  CHECK_EQ(1, EVP_DigestInit_ex(ctx, EVP_sha256(), NULL));
  for (const auto& e : mb.cdata()) {
    CHECK_EQ(1, EVP_DigestUpdate(ctx, e.data(), e.size()));
  }
  unsigned temp;
  CHECK_EQ(1, EVP_DigestFinal_ex(ctx, md, &temp));
}*/

void Hexify(const uint8_t* str, size_t len, char* dest) {
  static constexpr char kHex[] = "0123456789abcdef";

  for (unsigned i = 0; i < len; ++i) {
    char c = str[i];
    *dest++ = kHex[(c >> 4) & 0xF];
    *dest++ = kHex[c & 0xF];
  }
  *dest = '\0';
}

/*void Sha256String(const ::boost::beast::multi_buffer& mb, char out[65]) {
  uint8_t hash[32];
  EVPDigest(mb, hash);

  Hexify(hash, sizeof(hash), out);
}*/

void Sha256String(string_view str, char out[65]) {
  uint8_t hash[32];
  unsigned temp;

  CHECK_EQ(1, EVP_Digest(str.data(), str.size(), hash, &temp, EVP_sha256(), NULL));

  Hexify(hash, sizeof(hash), out);
}

void HMAC(absl::string_view key, absl::string_view msg, uint8_t dest[32]) {
  // The HMAC_xxx functions are deprecated since OpenSSL 3.0,
  // but Ubuntu 20.04 still ships OpenSSL 1.1.

  unsigned len = 32;
#if 0
  HMAC_CTX* hmac = HMAC_CTX_new();

  CHECK_EQ(1, HMAC_Init_ex(hmac, reinterpret_cast<const uint8_t*>(key.data()), key.size(),
                           EVP_sha256(), NULL));

  CHECK_EQ(1, HMAC_Update(hmac, reinterpret_cast<const uint8_t*>(msg.data()), msg.size()));

  uint8_t* ptr = reinterpret_cast<uint8_t*>(dest);

  CHECK_EQ(1, HMAC_Final(hmac, ptr, &len));
  HMAC_CTX_free(hmac);
#else
  const uint8_t* data = reinterpret_cast<const uint8_t*>(msg.data());
  ::HMAC(EVP_sha256(), key.data(), key.size(), data, msg.size(), dest, &len);
#endif
  CHECK_EQ(len, 32u);
}

#pragma GCC diagnostic ignored "-Wdeprecated-declarations"

string DeriveSigKey(absl::string_view key, absl::string_view datestamp, absl::string_view region,
                    absl::string_view service) {
  uint8_t sign[32];
  HMAC_CTX* hmac = HMAC_CTX_new();
  unsigned len;

  string start_key{"AWS4"};
  string_view sign_key{reinterpret_cast<const char*>(sign), sizeof(sign)};

  // TODO: replace with EVP_MAC_CTX_new / EVP_MAC_CTX_free etc., which appeared
  // only in OpenSSL 3.0.
  absl::StrAppend(&start_key, key);
  CHECK_EQ(1, HMAC_Init_ex(hmac, start_key.data(), start_key.size(), EVP_sha256(), NULL));
  CHECK_EQ(1, HMAC_Update(hmac, reinterpret_cast<const uint8_t*>(datestamp.data()),
                          datestamp.size()));
  CHECK_EQ(1, HMAC_Final(hmac, sign, &len));

  CHECK_EQ(1, HMAC_Init_ex(hmac, sign_key.data(), sign_key.size(), EVP_sha256(), NULL));
  CHECK_EQ(1, HMAC_Update(hmac, reinterpret_cast<const uint8_t*>(region.data()), region.size()));
  CHECK_EQ(1, HMAC_Final(hmac, sign, &len));

  CHECK_EQ(1, HMAC_Init_ex(hmac, sign_key.data(), sign_key.size(), EVP_sha256(), NULL));
  CHECK_EQ(1, HMAC_Update(hmac, reinterpret_cast<const uint8_t*>(service.data()), service.size()));
  CHECK_EQ(1, HMAC_Final(hmac, sign, &len));

  const char* sr = "aws4_request";
  CHECK_EQ(1, HMAC_Init_ex(hmac, sign_key.data(), sign_key.size(), EVP_sha256(), NULL));
  CHECK_EQ(1, HMAC_Update(hmac, reinterpret_cast<const uint8_t*>(sr), strlen(sr)));
  CHECK_EQ(1, HMAC_Final(hmac, sign, &len));

  HMAC_CTX_free(hmac);

  return string(sign_key);
}

inline std::string_view std_sv(const ::boost::beast::string_view s) {
  return std::string_view{s.data(), s.size()};
}

constexpr char kAlgo[] = "AWS4-HMAC-SHA256";

// Try reading AwsConnectionData from the environment.
std::optional<AwsConnectionData> GetConnectionDataFromEnv() {
  const char* access_key = getenv("AWS_ACCESS_KEY_ID");
  const char* secret_key = getenv("AWS_SECRET_ACCESS_KEY");
  const char* session_token = getenv("AWS_SESSION_TOKEN");
  const char* region = getenv("AWS_REGION");

  if (access_key && secret_key) {
    AwsConnectionData cd;
    cd.access_key = access_key;
    cd.secret_key = secret_key;
    if (session_token)
      cd.session_token = session_token;
    if (region)
      cd.region = region;
    return cd;
  }

  return std::nullopt;
}

// Get the path from env_var if it is set, otherwise the default path relative to the user home.
std::optional<std::string> GetAlternativePath(std::string_view default_home_postfix,
                                              const char* env_var) {
  const char* path_override = getenv(env_var);
  if (path_override)
    return path_override;

  const char* home_folder = getenv("HOME");
  if (!home_folder)
    return std::nullopt;

  return absl::StrCat(home_folder, default_home_postfix);
}

std::optional<io::ini::Contents> ReadIniFile(std::string_view full_path) {
  auto file = io::OpenRead(full_path, io::ReadonlyFile::Options{});
  if (!file)
    return std::nullopt;

  io::FileSource file_source{file.value()};
  auto contents = ::io::ini::Parse(&file_source, Ownership::DO_NOT_TAKE_OWNERSHIP);
  if (!contents) {
    LOG(ERROR) << "Failed to parse ini file: " << full_path;
    return std::nullopt;
  }

  return contents.value();
}

// Try filling AwsConnectionData with data from the config file.
void GetConfigFromFile(const char* profile, AwsConnectionData* cd) {
  auto full_path = GetAlternativePath("/.aws/config", "AWS_CONFIG_FILE");
  if (!full_path)
    return;

  auto contents = ReadIniFile(*full_path);
  if (!contents)
    return;

  auto it = contents->find(profile);
  if (it != contents->end()) {
    cd->region = it->second["region"];
  }
}

// Try reading AwsConnectionData from the credentials file.
std::optional<AwsConnectionData> GetConnectionDataFromFile() {
  // Get credentials path.
  auto full_path = GetAlternativePath("/.aws/credentials", "AWS_SHARED_CREDENTIALS_FILE");
  if (!full_path)
    return std::nullopt;

  // Read credentials file.
  auto contents = ReadIniFile(*full_path);
  if (!contents)
    return std::nullopt;

  // Read profile data.
  const char* profile = getenv("AWS_PROFILE");
  if (profile == nullptr)
    profile = "default";

  auto it = contents->find(profile);
  if (it != contents->end()) {
    AwsConnectionData cd;
    cd.access_key = it->second["aws_access_key_id"];
    cd.secret_key = it->second["aws_secret_access_key"];
    cd.session_token = it->second["aws_session_token"];
    GetConfigFromFile(profile, &cd);
    return cd;
  }

  if (profile != "default"sv) {
    LOG(ERROR) << "Failed to find profile: " << profile << " in credentials";
  }
  return std::nullopt;
}

// Make a simple GET request on path and return the body.
std::optional<std::string> MakeGetRequest(boost::string_view path, http::Client* http_client) {
  h2::request<h2::empty_body> req{h2::verb::get, path, 11};
  h2::response<h2::string_body> resp;
  req.set(h2::field::host, http_client->host());

  std::error_code ec = http_client->Send(req, &resp);
  if (ec || resp.result() != h2::status::ok)
    return std::nullopt;

  VLOG(1) << "Received response: " << resp;
  if (resp[h2::field::connection] == "close") {
    ec = http_client->Reconnect();
    if (ec)
      return std::nullopt;
  }
  return resp.body();
}

void GetConfigFromMetadata(http::Client* http_client, AwsConnectionData* cd) {
  const char* PATH = "/latest/dynamic/instance-identity/document";

  auto resp = MakeGetRequest(PATH, http_client);
  if (!resp)
    return;

  rapidjson::Document doc;
  doc.Parse(resp->c_str());
  if (doc.HasMember("region")) {
    cd->region = doc["region"].GetString();
  }
}

// Try getting AwsConnectionData from the instance metadata service.
std::optional<AwsConnectionData> GetConnectionDataFromMetadata(
    std::string_view hinted_role_name = ""sv) {
  ProactorBase* pb = ProactorBase::me();
  CHECK(pb);

  http::Client http_client{pb};
  error_code ec = http_client.Connect("169.254.169.254", "80");
  if (ec)
    return std::nullopt;

  const char* PATH = "/latest/meta-data/iam/security-credentials/";

  // Get the role name if none was provided.
  std::string role_name{hinted_role_name};
  if (role_name.empty()) {
    auto fetched_role = MakeGetRequest(PATH, &http_client);
    if (!fetched_role) {
      LOG(ERROR) << "Failed to get role name from metadata";
      return std::nullopt;
    }
    role_name = std::move(*fetched_role);
  }

  // Get credentials.
  std::string path = absl::StrCat(PATH, role_name);
  auto resp = MakeGetRequest(path, &http_client);
  if (!resp)
    return std::nullopt;
  VLOG(1) << "Received response: " << *resp;

  rapidjson::Document doc;
  doc.Parse(resp->c_str());
  if (!doc.HasMember("AccessKeyId") || !doc.HasMember("SecretAccessKey"))
    return std::nullopt;

  AwsConnectionData cd;
  cd.access_key = doc["AccessKeyId"].GetString();
  cd.secret_key = doc["SecretAccessKey"].GetString();
  if (doc.HasMember("Token")) {
    cd.session_token = doc["Token"].GetString();
  }
  cd.role_name = role_name;
  GetConfigFromMetadata(&http_client, &cd);

  return cd;
}

AwsConnectionData GetConnectionData() {
  std::optional<AwsConnectionData> keys;

  keys = GetConnectionDataFromEnv();
  if (keys)
    return *keys;

  keys = GetConnectionDataFromFile();
  if (keys)
    return *keys;

  keys = GetConnectionDataFromMetadata();
  if (keys)
    return *keys;

  LOG(ERROR) << "Failed to find a valid source for AWS connection data";
  return {};
}

void PopulateAwsConnectionData(const AwsConnectionData& src, AwsConnectionData* dest) {
  // Don't overwrite the region, as it can be provided via a flag.
  std::string region = dest->region;

  *dest = src;

  if (!region.empty())
    dest->region = region;
}

const char* kExpiredTokenSentinel = "ExpiredToken";

// Return true if the response body indicates an expired token.
bool IsExpiredBody(string_view body) {
  return body.find(kExpiredTokenSentinel) != std::string::npos;
}

bool AwsIsEscaped(char c) {
  return !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') ||
           c == '-' || c == '.' || c == '_' || c == '~');
}

// Escapes the given path as documented by
// https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html.
std::string AwsEscapePath(const std::string_view& path, bool encode_sep) {
  std::string escaped;
  for (char c : path) {
    if (!AwsIsEscaped(c) || (c == '/' && !encode_sep)) {
      escaped.push_back(c);
    } else {
      absl::StrAppendFormat(&escaped, "%%%02X", c);
    }
  }
  return escaped;
}

}  // namespace

const char AWS::kEmptySig[] = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
const char AWS::kUnsignedPayloadSig[] = "UNSIGNED-PAYLOAD";

AwsSignKey::AwsSignKey(std::string_view service, AwsConnectionData connection_data)
    : service_(service), now_(0), connection_data_(std::move(connection_data)) {
  RefreshIfNeeded();
}

void AwsSignKey::Sign(string_view payload_sig, HttpHeader* header) const {
  const absl::TimeZone utc_tz = absl::UTCTimeZone();

  RefreshIfNeeded();

  // We should consider passing the time via an argument to make this function test-friendly.
  // It must be recent (up to 900 sec of skew is allowed vs Amazon servers).
  absl::Time tn = absl::Now();

  string amz_date = absl::FormatTime("%Y%m%dT%H%M00Z", tn, utc_tz);
  header->set("x-amz-date", amz_date);

  // Older beast versions require passing beast::string_view.
  header->set("x-amz-content-sha256",
              boost::beast::string_view{payload_sig.data(), payload_sig.size()});

  const std::string& session_token = connection_data_.session_token;
  if (!session_token.empty()) {
    header->set("x-amz-security-token", session_token);
  }

  /// The canonical headers must include the following:
  ///
  /// - The HTTP host header.
  /// - If the Content-Type header is present in the request, you must add it as well.
  /// - Any x-amz-* headers that you plan to include in your request must also be added.
  ///   For example, if you are using temporary security credentials, you need
  ///   to include x-amz-security-token in your request and add it to the canonical header list.

  // TODO: right now the header list is hardcoded; if we need more flexible headers,
  // this code must change.
  string canonical_headers =
      absl::StrCat("host", ":", std_sv(header->find(h2::field::host)->value()), "\n");
  absl::StrAppend(&canonical_headers, "x-amz-content-sha256", ":", payload_sig, "\n");
  absl::StrAppend(&canonical_headers, "x-amz-date", ":", amz_date, "\n");
  if (!session_token.empty()) {
    absl::StrAppend(&canonical_headers, "x-amz-security-token", ":", session_token, "\n");
  }

  SignHeaders sheaders;
  sheaders.method = std_sv(header->method_string());
  sheaders.target = std_sv(header->target());
  sheaders.content_sha256 = payload_sig;
  sheaders.amz_date = amz_date;
  sheaders.headers = canonical_headers;

  string auth_header = AuthHeader(sheaders);

  header->set(h2::field::authorization, auth_header);
}

void AwsSignKey::RefreshIfNeeded() const {
  const absl::TimeZone utc_tz = absl::UTCTimeZone();

  absl::CivilDay date_before(utc_tz.At(absl::FromTimeT(now_)).cs);

  // Must be recent (up to 900 sec of skew is allowed vs Amazon servers).
  time_t now = time(NULL);
  absl::Time tn = absl::FromTimeT(now);
  absl::CivilDay date_now(utc_tz.At(tn).cs);

  if (date_now == date_before)
    return;

  now_ = now;
  string amz_date = absl::FormatTime("%Y%m%d", tn, utc_tz);
  DVLOG(1) << "Refreshing sign key to date " << amz_date;

  sign_key_ =
      DeriveSigKey(connection_data_.secret_key, amz_date, connection_data_.region, service_);

  credential_scope_ =
      absl::StrCat(amz_date, "/", connection_data_.region, "/", service_, "/", "aws4_request");
}

string AwsSignKey::AuthHeader(const SignHeaders& headers) const {
  CHECK(!connection_data_.access_key.empty());

  size_t pos = headers.target.find('?');
  string_view url = headers.target.substr(0, pos);
  string_view query_string;
  string canonical_querystring;

  if (pos != string::npos) {
    query_string = headers.target.substr(pos + 1);

    // We must sign the query string with its params in alphabetical order.
    vector<string_view> params = absl::StrSplit(query_string, "&", absl::SkipWhitespace{});
    sort(params.begin(), params.end());
    canonical_querystring = absl::StrJoin(params, "&");
  }

  string canonical_request = absl::StrCat(headers.method, "\n", AwsEscapePath(url, false), "\n",
                                          canonical_querystring, "\n");
  string signed_headers = "host;x-amz-content-sha256;x-amz-date";
  if (!connection_data_.session_token.empty()) {
    signed_headers += ";x-amz-security-token";
  }

  absl::StrAppend(&canonical_request, headers.headers, "\n", signed_headers, "\n",
                  headers.content_sha256);
  VLOG(2) << "CanonicalRequest:\n" << canonical_request << "\n-------------------\n";

  char hexdigest[65];
  Sha256String(canonical_request, hexdigest);

  string string_to_sign =
      absl::StrCat(kAlgo, "\n", headers.amz_date, "\n", credential_scope_, "\n", hexdigest);

  uint8_t signature[32];
  HMAC(sign_key_, string_to_sign, signature);
  Hexify(signature, sizeof(signature), hexdigest);

  string authorization_header =
      absl::StrCat(kAlgo, " Credential=", connection_data_.access_key, "/", credential_scope_,
                   ",SignedHeaders=", signed_headers, ",Signature=", hexdigest);

  return authorization_header;
}

bool AWS::RefreshToken() const {
  unique_lock lk(mu_);

  if (!connection_data_.role_name.empty()) {
    VLOG(1) << "Trying to update the expired session token";

    auto updated_data = GetConnectionDataFromMetadata(connection_data_.role_name);
    if (!updated_data)
      return false;

    PopulateAwsConnectionData(std::move(*updated_data), &connection_data_);
    return true;
  }

  return false;
}

error_code AWS::Init() {
  PopulateAwsConnectionData(GetConnectionData(), &connection_data_);

  if (connection_data_.access_key.empty()) {
    LOG(WARNING) << "Can not find AWS_ACCESS_KEY_ID";
    return make_error_code(errc::operation_not_permitted);
  }

  if (connection_data_.secret_key.empty()) {
    LOG(WARNING) << "Can not find AWS_SECRET_ACCESS_KEY";
    return make_error_code(errc::operation_not_permitted);
  }

  return error_code{};
}

AwsSignKey AWS::GetSignKey(string_view region) const {
  unique_lock lk(mu_);

  DCHECK(!connection_data_.access_key.empty()) << "Init() was not called";

  auto cd = connection_data_;
  cd.region = region;

  return AwsSignKey(service_, std::move(cd));
}

error_code AWS::RetryExpired(http::Client* client, AwsSignKey* cached_key, EmptyBodyReq* req,
                             h2::response_header<>* header) const {
  if (!RefreshToken()) {
    LOG(ERROR) << "Could not refresh token";
    return make_error_code(errc::io_error);
  }

  string region = cached_key->connection_data().region;
  *cached_key = GetSignKey(region);
  cached_key->Sign(AWS::kEmptySig, req);
  return client->Send(*req);
}

error_code AWS::SendRequest(http::Client* client, AwsSignKey* cached_key,
                            h2::request<h2::empty_body>* req,
                            h2::response<h2::string_body>* resp) const {
  cached_key->Sign(AWS::kEmptySig, req);
  DVLOG(2) << "Signed request: " << *req;

  // We may experience connection disconnects; we catch the ECONNABORTED error during the Recv
  // step. TCP_KEEPALIVE settings do not seem to help.
  for (unsigned i = 0; i < 2; ++i) {
    RETURN_ON_ERR(client->Send(*req));

    auto bec = client->Recv(resp);
    if (!bec)  // success
      break;
    if (bec.value() != boost::system::errc::connection_aborted)
      return bec;

    RETURN_ON_ERR(client->Reconnect());
  }

  DVLOG(2) << "Received response: " << *resp;
  auto& headers = resp->base();
  if (headers[h2::field::connection] == "close") {
    RETURN_ON_ERR(client->Reconnect());
  }

  if (resp->result() == h2::status::bad_request && IsExpiredBody(resp->body())) {
    RETURN_ON_ERR(RetryExpired(client, cached_key, req, &resp->base()));

    resp->clear();
    RETURN_ON_ERR(client->Recv(resp));
  }
  return error_code{};
}

error_code AWS::Handshake(http::Client* client, AwsSignKey* cached_key, EmptyBodyReq* req,
                          HttpParser* parser) const {
  cached_key->Sign(AWS::kEmptySig, req);
  VLOG(1) << "Sending request: " << *req;

  RETURN_ON_ERR(client->Send(*req));

  parser->body_limit(UINT64_MAX);

  RETURN_ON_ERR(client->ReadHeader(parser));

  auto& msg = parser->get();

  if (msg.result() == h2::status::bad_request) {
    string str(512, '\0');
    msg.body().data = str.data();
    msg.body().size = str.size();
    RETURN_ON_ERR(client->Recv(parser));

    if (IsExpiredBody(str)) {
      RETURN_ON_ERR(RetryExpired(client, cached_key, req, &msg.base()));

      // TODO: it seems we can not reuse the parser here.
      // (*parser) = std::move(HttpParser{});
      parser->body_limit(UINT64_MAX);
      RETURN_ON_ERR(client->ReadHeader(parser));
    }
  }

  return error_code{};
}

}  // namespace cloud
}  // namespace util
diff --git a/util/cloud/aws.h b/util/cloud/aws.h
new file mode 100644
index 00000000..b89d4cf2
--- /dev/null
+++ b/util/cloud/aws.h
@@ -0,0 +1,105 @@
// Copyright 2022, Roman Gershman.  All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once

#include <boost/beast/http/buffer_body.hpp>
#include <boost/beast/http/empty_body.hpp>
#include <boost/beast/http/parser.hpp>
#include <boost/beast/http/string_body.hpp>
#include <string>

#include "util/fibers/synchronization.h"
#include "util/http/http_client.h"

namespace util {
namespace cloud {

namespace h2 = boost::beast::http;

struct AwsConnectionData {
  std::string access_key, secret_key, session_token;
  std::string region, role_name;
};

class AwsSignKey {
 public:
  using HttpHeader = ::boost::beast::http::header<true, ::boost::beast::http::fields>;

  AwsSignKey() = default;

  AwsSignKey(std::string_view service, AwsConnectionData connection_data);

  void Sign(std::string_view payload_sig, HttpHeader* header) const;

  const AwsConnectionData& connection_data() const {
    return connection_data_;
  }

  time_t now() const {
    return now_;
  }

 private:
  void RefreshIfNeeded() const;

  struct SignHeaders {
    std::string_view method, headers, target;
    std::string_view content_sha256, amz_date;
  };

  std::string AuthHeader(const SignHeaders& headers) const;

  std::string service_;
  mutable std::string sign_key_, credential_scope_;
  mutable time_t now_;  // epoch time.

  AwsConnectionData connection_data_;
};

class AWS {
 public:
  static const char kEmptySig[];
  static const char kUnsignedPayloadSig[];
  using HttpParser = ::boost::beast::http::response_parser<::boost::beast::http::buffer_body>;
  using EmptyBodyReq = ::boost::beast::http::request<::boost::beast::http::empty_body>;

  AWS(const std::string& service, const std::string& region = "") : service_(service) {
    connection_data_.region = region;
  }

  // Init must be run in a proactor thread.
  std::error_code Init();

  const AwsConnectionData& connection_data() const {
    return connection_data_;
  }

  // Returns true if it succeeded to refresh the metadata.
  // Thread-safe.
  bool RefreshToken() const;

  AwsSignKey GetSignKey(std::string_view region) const;

  std::error_code SendRequest(
      http::Client* client, AwsSignKey* cached_key, EmptyBodyReq* req,
      ::boost::beast::http::response<::boost::beast::http::string_body>* resp) const;

  // Sends a request and reads back the header response. Handles the response according to the
  // header. The caller is responsible for reading the rest of the response via the parser.
  std::error_code Handshake(http::Client* client, AwsSignKey* cached_key, EmptyBodyReq* req,
                            HttpParser* resp) const;

 private:
  std::error_code RetryExpired(http::Client* client, AwsSignKey* cached_key, EmptyBodyReq* req,
                               ::boost::beast::http::response_header<>* header) const;

  std::string service_;
  mutable AwsConnectionData connection_data_;

  mutable fb2::Mutex mu_;
};

}  // namespace cloud

}  // namespace util
diff --git a/util/cloud/s3.cc b/util/cloud/s3.cc
new file mode 100644
index 00000000..372ef3b9
--- /dev/null
+++ b/util/cloud/s3.cc
@@ -0,0 +1,368 @@
// Copyright 2022, Roman Gershman.  All rights reserved.
// See LICENSE for licensing terms.
//

#include "util/cloud/s3.h"

#include <absl/strings/match.h>
#include <absl/strings/str_cat.h>
#include <netinet/tcp.h>

#include <boost/beast/http/empty_body.hpp>
#include <boost/beast/http/string_body.hpp>
#include <pugixml.hpp>

#include "base/logging.h"
#include "util/cloud/aws.h"
#include "util/cloud/s3_file.h"
#include "util/fibers/proactor_base.h"
#include "util/http/encoding.h"

namespace util {
namespace cloud {

using namespace std;
namespace h2 = boost::beast::http;
using nonstd::make_unexpected;

// Max number of keys in an AWS response.
const unsigned kAwsMaxKeys = 1000;

inline std::string_view std_sv(const ::boost::beast::string_view s) {
  return std::string_view{s.data(), s.size()};
}

bool IsAwsEndpoint(string_view endpoint) {
  return absl::EndsWith(endpoint, ".amazonaws.com");
}

namespace xml {

ListBucketsResult ParseXmlListBuckets(string_view xml_resp) {
  pugi::xml_document doc;
  pugi::xml_parse_result result = doc.load_buffer(xml_resp.data(), xml_resp.size());
  if (!result) {
    LOG(ERROR) << "Could not parse xml response " << result.description();
    return make_unexpected(make_error_code(errc::bad_message));
  }

  pugi::xml_node root = doc.child("ListAllMyBucketsResult");
  if (root.type() != pugi::node_element) {
    LOG(ERROR) << "Could not find the root node " << xml_resp;
    return make_unexpected(make_error_code(errc::bad_message));
  }

  pugi::xml_node buckets = root.child("Buckets");
  if (buckets.type() != pugi::node_element) {
    LOG(ERROR) << "Could not find the buckets node " << xml_resp;
    return make_unexpected(make_error_code(errc::bad_message));
  }

  vector<string> res;
  for (pugi::xml_node bucket = buckets.child("Bucket"); bucket; bucket = bucket.next_sibling()) {
    res.emplace_back(bucket.child("Name").text().get());
  }

  return res;
}

ListObjectsResult ParseListObj(string_view xml_resp, S3Bucket::ListObjectCb cb) {
  pugi::xml_document doc;
  pugi::xml_parse_result result = doc.load_buffer(xml_resp.data(), xml_resp.size());
  if (!result) {
    LOG(ERROR) << "Could not parse xml response " << result.description();
    return make_unexpected(make_error_code(errc::bad_message));
  }

  string_view last_key;

  pugi::xml_node root = doc.child("ListBucketResult");
  if (root.type() != pugi::node_element) {
    LOG(ERROR) << "Could not find the root node " << xml_resp;
    return make_unexpected(make_error_code(errc::bad_message));
  }

  // text() provides a convenient interface that avoids checking for potentially missing
  // fields and relies on the defaults.
  bool truncated = root.child("IsTruncated").text().as_bool();
  for (pugi::xml_node contents = root.child("Contents"); contents;
       contents = contents.next_sibling("Contents")) {
    size_t sz = contents.child("Size").text().as_ullong();
    string_view key = contents.child("Key").text().get();
    if (!key.empty())
      cb(sz, key);
    last_key = key;
  }

  return truncated ? std::string{last_key} : "";
}
}  // namespace xml

ListBucketsResult ListS3Buckets(AWS* aws, http::Client* http_client) {
  h2::request<h2::empty_body> req{h2::verb::get, "/", 11};
  req.set(h2::field::host, http_client->host());
  h2::response<h2::string_body> resp;

  AwsSignKey skey = aws->GetSignKey(aws->connection_data().region);
  auto ec = aws->SendRequest(http_client, &skey, &req, &resp);
  if (ec) {
    return make_unexpected(ec);
  }

  if (resp.result() != h2::status::ok) {
    LOG(ERROR) << "http error: " << resp;
    return make_unexpected(make_error_code(errc::inappropriate_io_control_operation));
  }

  VLOG(1) << "ListS3Buckets: " << resp;

  return xml::ParseXmlListBuckets(resp.body());
}

S3Bucket::S3Bucket(const AWS& aws, string_view bucket, string_view region)
    : aws_(aws), bucket_(bucket), region_(region) {
  CHECK(!bucket.empty());

  if (region.empty()) {
    region = aws_.connection_data().region;
    if (region.empty()) {
      region = "us-east-1";
    }
  }
  skey_ = aws.GetSignKey(region);
}

S3Bucket S3Bucket::FromEndpoint(const AWS& aws, string_view endpoint, string_view bucket) {
  S3Bucket res(aws, bucket);
  res.endpoint_ = endpoint;
  return res;
}

std::error_code S3Bucket::Connect(uint32_t ms) {
  ProactorBase* pb = ProactorBase::me();
  CHECK(pb);

  http_client_.reset(new http::Client{pb});
  http_client_->AssignOnConnect([](int fd) {
    int val = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &val, sizeof(val)) < 0)
      return;

    val = 20;
    if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &val, sizeof(val)) < 0)
      return;

    val = 60;
#ifdef __APPLE__
    if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPALIVE, &val, sizeof(val)) < 0)
      return;
#else
    if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &val, sizeof(val)) < 0)
      return;
#endif

    val = 3;
    if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &val, sizeof(val)) < 0)
      return;
  });

  http_client_->set_connect_timeout_ms(ms);

  return ConnectInternal();
}

ListObjectsResult S3Bucket::ListObjects(string_view bucket_path, ListObjectCb cb,
                                        std::string_view marker, unsigned max_keys) {
  CHECK_LE(max_keys, kAwsMaxKeys);

  string host = http_client_->host();
  std::string path;

  // Build the full request path.
  if (IsAwsEndpoint(host)) {
    path.append("/?");
  } else {
    path.append("/").append(bucket_).append("?");
  }

  if (bucket_path != "")
    path += absl::StrCat("prefix=", util::http::UrlEncode(bucket_path), "&");

  if (marker != "")
    path += absl::StrCat("marker=", util::http::UrlEncode(marker), "&");

  if (max_keys != kAwsMaxKeys)
    path += absl::StrCat("max-keys=", max_keys, "&");

  CHECK(path.back() == '?' || path.back() == '&');
  path.pop_back();

  // Send request.
  h2::request<h2::empty_body> req{h2::verb::get, path, 11};
  req.set(h2::field::host, host);

  h2::response<h2::string_body> resp;

  error_code ec = aws_.SendRequest(http_client_.get(), &skey_, &req, &resp);
  if (ec)
    return make_unexpected(ec);

  if (resp.result() != h2::status::ok) {
    LOG(ERROR) << "http error: " << resp;
    return make_unexpected(make_error_code(errc::connection_refused));
  }

  if (!absl::StartsWith(resp.body(), "<?xml")) {
    LOG(ERROR) << "Unexpected response: " << resp;
    return make_unexpected(make_error_code(errc::bad_message));
  }

  return xml::ParseListObj(resp.body(), std::move(cb));
}

error_code S3Bucket::ListAllObjects(string_view path, ListObjectCb cb) {
  std::string marker;
  unsigned max_keys = kAwsMaxKeys;
  do {
    auto res = ListObjects(path, cb, marker, max_keys);
    if (!res)
      return res.error();

    marker = res.value();
  } while (!marker.empty());

  return error_code{};
}

io::Result<io::ReadonlyFile*> S3Bucket::OpenReadFile(string_view path,
                                                     const io::ReadonlyFile::Options& opts) {
  string host = http_client_->host();
  string full_path{path};
  if (!IsAwsEndpoint(host)) {
    full_path = absl::StrCat(bucket_, "/", path);
  }

  unique_ptr<http::Client> http_client = std::move(http_client_);
  error_code ec = Connect(http_client->connect_timeout_ms());
  if (ec)
    return make_unexpected(ec);

  return OpenS3ReadFile(region_, full_path, aws_, std::move(http_client), opts);
}

io::Result<io::WriteFile*> S3Bucket::OpenWriteFile(std::string_view path) {
  string host = http_client_->host();
  string full_path{path};
  if (!IsAwsEndpoint(host)) {
    full_path = absl::StrCat(bucket_, "/", path);
  }

  unique_ptr<http::Client> http_client = std::move(http_client_);
  error_code ec = Connect(http_client->connect_timeout_ms());
  if (ec)
    return make_unexpected(ec);

  return OpenS3WriteFile(region_, full_path, aws_, std::move(http_client));
}

string S3Bucket::GetHost() const {
  if (!endpoint_.empty())
    return endpoint_;

  // Fall back to the default AWS endpoint.
  if (region_.empty())
    return "s3.amazonaws.com";
  return absl::StrCat("s3.", region_, ".amazonaws.com");
}

error_code S3Bucket::ConnectInternal() {
  string host = GetHost();
  auto pos = host.rfind(':');
  string port;

  if (pos != string::npos) {
    port = host.substr(pos + 1);
    host = host.substr(0, pos);
  } else {
    port = "80";
  }

  bool is_aws = IsAwsEndpoint(host);
  if (is_aws)
    host = absl::StrCat(bucket_, ".", host);

  VLOG(1) << "Connecting to " << host << ":" << port;
  auto ec = http_client_->Connect(host, port);
  if (ec)
    return ec;

  if (region_.empty()) {
    ec = DeriveRegion();
  }

  return ec;
}

error_code S3Bucket::DeriveRegion() {
  h2::request<h2::empty_body> req(h2::verb::head, "/", 11);
  req.set(h2::field::host, http_client_->host());
  bool is_aws = IsAwsEndpoint(http_client_->host());

  h2::response_parser<h2::string_body> parser;

  if (is_aws) {
    parser.skip(true);  // for HEAD requests we do not get the body.
  } else {
    string url = absl::StrCat("/", bucket_, "?location=");
    req.target(url);

    // TODO: can we keep HEAD for other providers?
    req.method(h2::verb::get);
  }

  skey_.Sign(AWS::kEmptySig, &req);
  error_code ec = http_client_->Send(req);
  if (ec)
    return ec;

  ec = http_client_->ReadHeader(&parser);
  if (ec)
    return ec;

  h2::header<false, h2::fields>& header = parser.get();

  // I deliberately do not check the http status here. AWS can return 400 or 403 and still
  // report the region.
  VLOG(1) << "LocationResp: " << header;
  auto src = header["x-amz-bucket-region"];
  if (src.empty()) {
    LOG(ERROR) << "x-amz-bucket-region is absent in response: " << header;
    return make_error_code(errc::bad_message);
  }

  region_ = std::string(src);
  skey_ = aws_.GetSignKey(region_);
  if (header[h2::field::connection] == "close") {
    ec = http_client_->Reconnect();
    if (ec)
      return ec;
  } else if (!parser.is_done()) {
    // Drain the http response.
    ec = http_client_->Recv(&parser);
  }

  return ec;
}

}  // namespace cloud
}  // namespace util
diff --git a/util/cloud/s3.h b/util/cloud/s3.h
new file mode 100644
index 00000000..47eac9b6
--- /dev/null
+++ b/util/cloud/s3.h
@@ -0,0 +1,69 @@
// Copyright 2022, Roman Gershman.  All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once

#include <functional>

#include "io/file.h"
#include "io/io.h"
#include "util/cloud/aws.h"
#include "util/http/http_client.h"

namespace util {
namespace cloud {

using ListBucketsResult = io::Result<std::vector<std::string>>;

// The inner result value is the 'marker' to start the next page from.
// Empty if no more pages are left.
using ListObjectsResult = io::Result<std::string>;

// List all S3 buckets. Refreshes the AWS token if needed.
ListBucketsResult ListS3Buckets(AWS* aws, http::Client* http_client);

class S3Bucket {
 public:
  S3Bucket(const S3Bucket&) = delete;
  S3Bucket(S3Bucket&&) = default;

  S3Bucket(const AWS& aws, std::string_view bucket, std::string_view region = std::string_view{});

  static S3Bucket FromEndpoint(const AWS& aws, std::string_view endpoint, std::string_view bucket);

  std::error_code Connect(uint32_t ms);

  //! Called with (size, key_name) pairs.
  using ListObjectCb = std::function<void(size_t, std::string_view)>;

  // Iterates over bucket objects for the given path, starting from a marker (default none).
  // Up to max_keys entries are returned; the possible maximum is 1000.
  // Returns the key to start the next query from if the result is truncated.
  ListObjectsResult ListObjects(std::string_view path, ListObjectCb cb,
                                std::string_view marker = "", unsigned max_keys = 1000);

  // Iterates over all bucket objects for the given path.
  std::error_code ListAllObjects(std::string_view path, ListObjectCb cb);

  io::Result<io::ReadonlyFile*> OpenReadFile(
      std::string_view path, const io::ReadonlyFile::Options& opts = io::ReadonlyFile::Options{});

  io::Result<io::WriteFile*> OpenWriteFile(std::string_view path);

 private:
  std::string GetHost() const;
  std::error_code ConnectInternal();
  std::error_code DeriveRegion();

  const AWS& aws_;

  std::string endpoint_;
  std::string bucket_;
  std::string region_;
  AwsSignKey skey_;
  std::unique_ptr<http::Client> http_client_;
};

}  // namespace cloud

}  // namespace util
diff --git a/util/cloud/s3_file.cc b/util/cloud/s3_file.cc
new file mode 100644
index 00000000..7a012fb7
--- /dev/null
+++ b/util/cloud/s3_file.cc
@@ -0,0 +1,523 @@
// Copyright 2023, Roman Gershman.  All rights reserved.
// See LICENSE for licensing terms.
//
#include "util/cloud/s3_file.h"

#include <absl/strings/numbers.h>
#include <absl/strings/str_cat.h>
#include <pugixml.hpp>

#include <boost/beast/core/multi_buffer.hpp>
#include <boost/beast/http/dynamic_body.hpp>

#include "base/logging.h"
#include "util/http/http_common.h"

using namespace std;

namespace h2 = boost::beast::http;

namespace util {
namespace cloud {

namespace {

#define RETURN_ON_ERR(x)                                     \
  do {                                                       \
    auto __ec = (x);                                         \
    if (__ec) {                                              \
      VLOG(1) << "Error " << __ec << " while calling " #x;   \
      return __ec;                                           \
    }                                                        \
  } while (0)

#define RETURN_UNEXPECTED(x)                                     \
  do {                                                           \
    auto __ec = (x);                                             \
    if (__ec) {                                                  \
      LOG(WARNING) << "Error " << __ec << " while calling " #x;  \
      return nonstd::make_unexpected(__ec);                      \
    }                                                            \
  } while (0)

// AWS requires at least a 5MB part size. We use 8MB.
constexpr size_t kMaxPartSize = 1ULL << 23;

inline void SetRange(size_t from, size_t to, h2::fields* flds) {
  string tmp = absl::StrCat("bytes=", from, "-");
  if (to < kuint64max) {
    absl::StrAppend(&tmp, to - 1);
  }
  flds->set(h2::field::range, std::move(tmp));
}

inline string_view ToSv(const boost::string_view s) {
  return string_view{s.data(), s.size()};
}

std::ostream& operator<<(std::ostream& os, const h2::response<h2::buffer_body>& msg) {
  os << msg.reason() << std::endl;
  for (const auto& f : msg) {
    os << f.name_string() << " : " << f.value() << std::endl;
  }
  os << "-------------------------";

  return os;
}

error_code DrainResponse(http::Client* client, h2::response_parser<h2::buffer_body>* parser) {
  char resp[512];
  auto& body = parser->get().body();
  while (!parser->is_done()) {
    body.data = resp;
    body.size = sizeof(resp);

    http::Client::BoostError ec = client->Recv(parser);
    if (ec && ec != h2::error::need_buffer) {
      return ec;
    }
  }
  return error_code{};
}

error_code ParseXmlStartUpload(std::string_view xml_resp, string* upload_id) {
  pugi::xml_document doc;
  pugi::xml_parse_result result = doc.load_buffer(xml_resp.data(), xml_resp.size());

  if (!result) {
    LOG(ERROR) << "ParseXmlStartUpload: " << result.description();
    return make_error_code(errc::bad_message);
  }
  pugi::xml_node upload_res = doc.child("InitiateMultipartUploadResult");
  if (upload_res.type() != pugi::node_element) {
    LOG(ERROR) << "Missing InitiateMultipartUploadResult " << xml_resp;
    return make_error_code(errc::bad_message);
  }

  pugi::xml_node upload_text = upload_res.child("UploadId");
  if (upload_text.type() != pugi::node_element) {
    LOG(ERROR) << "Missing UploadId " << upload_res.text().get();
    return make_error_code(errc::bad_message);
  }
  *upload_id = upload_text.child_value();

  return error_code{};
}

io::Result<string> InitiateMultiPart(string_view key_path, const AWS& aws, AwsSignKey* skey,
                                     http::Client* client) {
  string url("/");
  url.append(key_path).append("?uploads=");

  // Signed params must look like key/value pairs. Instead of handling key-only params
  // in the signing code, we just pass an empty value here.

  h2::request<h2::empty_body> req{h2::verb::post, url, 11};
  h2::response<h2::string_body> resp;

  req.set(h2::field::host, client->host());

  RETURN_UNEXPECTED(aws.SendRequest(client, skey, &req, &resp));

  if (resp.result() != h2::status::ok) {
    LOG(ERROR) << "InitiateMultiPart Error: " << resp;

    return nonstd::make_unexpected(make_error_code(errc::io_error));
  }

  string upload_id;

  RETURN_UNEXPECTED(ParseXmlStartUpload(resp.body(), &upload_id));

  VLOG(1) << "InitiateMultiPart: " << req << "/" << resp << "\nUploadId: " << upload_id;
  return upload_id;
}

class S3ReadFile final : public io::ReadonlyFile {
 public:
  // Does not own the pool object; it only wraps it with the ReadonlyFile interface.
  S3ReadFile(const AWS& aws, string read_obj_url, std::unique_ptr<http::Client> client)
      : aws_(aws), client_(std::move(client)), read_obj_url_(std::move(read_obj_url)) {
  }

  ~S3ReadFile() override;

  // Reads up to length bytes and updates the result to point to the data.
  // May use the buffer for storing the data. If EOF is reached, the result size is smaller
  // than the requested length, but OK is still returned.
  io::Result<size_t> Read(size_t offset, const iovec* v, uint32_t len) final;

  std::error_code Open(std::string_view region);

  // Releases the system handle for this file.
  std::error_code Close() final;

  size_t Size() const final {
    return size_;
  }

  int Handle() const final {
    return -1;
  }

 private:
  AWS::HttpParser* parser() {
    return &parser_;
  }

  const AWS& aws_;
  std::unique_ptr<http::Client> client_;

  const string read_obj_url_;

  AWS::HttpParser parser_;
  size_t size_ = 0, offs_ = 0;
  AwsSignKey sign_key_;
};

class S3WriteFile : public io::WriteFile {
 public:
  /**
   * @brief Construct a new S3 Write File object.
   *
   * @param key_name - aka "s3://somebucket/path_to_obj"
   * @param region - the S3 region of the bucket.
   * @param aws - an initialized AWS object.
   * @param client - an HTTP client connected to the S3 endpoint.
   */
  S3WriteFile(string_view key_name, string_view region, const AWS& aws,
              std::unique_ptr<http::Client> client);

  error_code Close() final;

  io::Result<size_t> WriteSome(const iovec* v, uint32_t len) final;

 private:
  size_t FillBody(const uint8* buffer, size_t length);
  error_code Upload();

  AwsSignKey skey_;
  const AWS& aws_;

  string upload_id_;
  unique_ptr<http::Client> client_;
  boost::beast::multi_buffer body_mb_;
  vector<string> parts_;
};

S3ReadFile::~S3ReadFile() {
  Close();
}

error_code S3ReadFile::Open(string_view region) {
  string url = absl::StrCat("/", read_obj_url_);
  h2::request<h2::empty_body> req{h2::verb::get, url, 11};
  req.set(h2::field::host, client_->host());

  if (offs_)
    SetRange(offs_, kuint64max, &req);

  VLOG(1) << "Unsigned request: " << req;
  sign_key_ = aws_.GetSignKey(region);
  RETURN_ON_ERR(aws_.Handshake(client_.get(), &sign_key_, &req, &parser_));

  const auto& msg = parser_.get();
  VLOG(1) << "HeaderResp: " << msg.result_int() << " " << msg;

  if (msg.result() == h2::status::not_found) {
    RETURN_ON_ERR(DrainResponse(client_.get(), &parser_));

    return make_error_code(errc::no_such_file_or_directory);
  }

  if (msg.result() == h2::status::bad_request) {
    return make_error_code(errc::bad_message);
  }

  CHECK(parser_.keep_alive()) << "TBD";

  auto content_len_it = msg.find(h2::field::content_length);
  if (content_len_it != msg.end()) {
    size_t content_sz = 0;
    CHECK(absl::SimpleAtoi(ToSv(content_len_it->value()), &content_sz));

    if (size_) {
      CHECK_EQ(size_, content_sz + offs_) << "File size has changed underneath during reopen";
    } else {
      size_ = content_sz;
    }
  }

  return error_code{};
}

error_code S3ReadFile::Close() {
  return error_code{};
}

io::Result<size_t> S3ReadFile::Read(size_t offset, const iovec* v, uint32_t len) {
  if (offset != offs_) {
    return nonstd::make_unexpected(make_error_code(errc::invalid_argument));
  }

  // We can not cache parser() into a local var because Open() below recreates the parser
  // instance.
  if (parser_.is_done()) {
    return 0;
  }

  size_t index = 0;
  size_t read_sofar = 0;

  while (index < len) {
    // We keep the body references inside the loop because Open(), which might be called here,
    // will recreate the parser from the point where the connection disconnected.
    auto& body = parser()->get().body();
    auto& left_available = body.size;
    body.data = v[index].iov_base;
    left_available = v[index].iov_len;

    boost::system::error_code ec = client_->Recv(parser());  // decreases left_available.
    size_t http_read = v[index].iov_len - left_available;

    if (!ec || ec == h2::error::need_buffer) {  // Success.
      DVLOG(2) << "Read " << http_read << " bytes from " << offset << " with capacity "
               << v[index].iov_len << " ec: " << ec;

      CHECK(left_available == 0 || !ec);

      // This check does not happen. See here why: https://github.com/boostorg/beast/issues/1662
      // DCHECK_EQ(sz_read, http_read) << " " << range.size() << "/" << left_available;
      offs_ += http_read;
      read_sofar += http_read;
      ++index;

      continue;
    }

    if (ec == h2::error::partial_message) {
      offs_ += http_read;
      VLOG(1) << "Got partial_message";

      // Advance the destination buffer as well.
      read_sofar += http_read;
      break;
    }

    LOG(ERROR) << "ec: " << ec << "/" << ec.message() << " at " << offset << "/" << size_;
    return nonstd::make_unexpected(ec);
  }

  return read_sofar;
}

S3WriteFile::S3WriteFile(string_view name, string_view region, const AWS& aws,
                         std::unique_ptr<http::Client> client)
    : WriteFile(name), aws_(aws), client_(std::move(client)), body_mb_(kMaxPartSize) {
  skey_ = aws_.GetSignKey(region);
}

error_code S3WriteFile::Close() {
  error_code ec = Upload();
  if (ec) {
    LOG(WARNING) << "S3WriteFile::Close: " << ec.message();
    return ec;
  }

  if (parts_.empty())
    return ec;

  DCHECK(!upload_id_.empty());
  VLOG(1) << "Finalizing " << upload_id_ << " with " << parts_.size() << " parts";

  string url("/");
  url.append(create_file_name_);

  // Signed params must look like key/value pairs. Instead of handling key-only params
  // in the signing code, we just pass an empty value here.
  absl::StrAppend(&url, "?uploadId=", upload_id_);

  h2::request<h2::string_body> req{h2::verb::post, url, 11};
  h2::response<h2::string_body> resp;

  req.set(h2::field::content_type, http::kXmlMime);
  req.set(h2::field::host, client_->host());

  auto& body = req.body();
  body = R"(<?xml version="1.0" encoding="UTF-8"?>
<CompleteMultipartUpload xmlns="http://s3.amazonaws.com/doc/2006-03-01/">)";
  body.append("\n");

  for (size_t i = 0; i < parts_.size(); ++i) {
    absl::StrAppend(&body, "<Part><ETag>\"", parts_[i], "\"</ETag><PartNumber>", i + 1);
    absl::StrAppend(&body, "</PartNumber></Part>\n");
  }
  body.append("</CompleteMultipartUpload>");

  req.prepare_payload();

  skey_.Sign(string_view{AWS::kUnsignedPayloadSig}, &req);

  ec = client_->Send(req, &resp);

  if (ec) {
    LOG(WARNING) << "S3WriteFile::Close: " << req << "/ " << resp << " ec: " << ec;
    return ec;
  }

  if (resp.result() != h2::status::ok) {
    LOG(ERROR) << "S3WriteFile::Close: " << req << "/ " << resp;

    return make_error_code(errc::io_error);
  }
  parts_.clear();

  return ec;
}

io::Result<size_t> S3WriteFile::WriteSome(const iovec* v, uint32_t len) {
  size_t total = 0;
  for (size_t i = 0; i < len; ++i) {
    size_t len = v[i].iov_len;
    const uint8_t* buffer = reinterpret_cast<const uint8_t*>(v[i].iov_base);

    while (len) {
      size_t written = FillBody(buffer, len);
      total += written;
      len -= written;
      buffer += written;
      if (body_mb_.size() == body_mb_.max_size()) {
        RETURN_UNEXPECTED(Upload());
      }
    }
  }

  return total;
}

size_t S3WriteFile::FillBody(const uint8* buffer, size_t length) {
  size_t prepare_size = std::min(length, body_mb_.max_size() - body_mb_.size());
  auto mbs = body_mb_.prepare(prepare_size);
  size_t offs = 0;
  for (auto mb : mbs) {
    memcpy(mb.data(), buffer + offs, mb.size());
    offs += mb.size();
  }
  CHECK_EQ(offs, prepare_size);
  body_mb_.commit(prepare_size);

  return offs;
}

error_code S3WriteFile::Upload() {
  size_t body_size = body_mb_.size();
  if (body_size == 0)
    return error_code{};

  h2::request<h2::dynamic_body> req{h2::verb::put, "", 11};
  req.set(h2::field::content_type, http::kBinMime);
  req.set(h2::field::host, client_->host());

  h2::response<h2::string_body> resp;
  string url("/");

  // TODO: figure out why SHA256 is so slow.
  // detail::Sha256String(body_mb_, sha256);
  absl::StrAppend(&url, create_file_name_);

  if (upload_id_.empty()) {
    if (body_size == body_mb_.max_size()) {
      auto res = InitiateMultiPart(create_file_name_, aws_, &skey_, client_.get());
      if (!res) {
        return res.error();
      }
      upload_id_ = std::move(*res);
    }
  }

  if (!upload_id_.empty()) {
    absl::StrAppend(&url, "?uploadId=", upload_id_);
    absl::StrAppend(&url, "&partNumber=", parts_.size() + 1);
  }

  req.target(url);
  req.body() = std::move(body_mb_);
  req.prepare_payload();

  skey_.Sign(string_view{AWS::kUnsignedPayloadSig}, &req);
  VLOG(2) << "UploadReq: " << req.base();

  bool etag_found = false;
  http::Client::BoostError bec;

  // Retry several times. During I/O-intensive operations we can get ECONNABORTED
  // or other weird artifacts.
  for (unsigned j = 0; j < 3; ++j) {
    bec = client_->Send(req, &resp);
    if (bec) {
      VLOG(1) << "Upload error: " << bec << " " << bec.message();
      RETURN_ON_ERR(client_->Reconnect());

      continue;
    }

    if (resp.result() != h2::status::ok) {
      LOG(ERROR) << "Upload error: " << resp.base();
      return make_error_code(errc::io_error);
    }

    VLOG(2) << "UploadResp: " << resp;

    if (resp[h2::field::connection] == "close") {
      RETURN_ON_ERR(client_->Reconnect());
    }

    if (upload_id_.empty())
      return error_code{};

    // Sometimes S3 returns an empty 200 response without any headers.
    auto it = resp.find(h2::field::etag);
    if (it != resp.end()) {
      auto sv = it->value();
      if (sv.size() <= 2) {
        return make_error_code(errc::io_error);
      }
      sv.remove_prefix(1);
      sv.remove_suffix(1);

      // sv.to_string() is missing in older versions of boost.beast.
      parts_.emplace_back(string(sv));
      etag_found = true;
      break;
    }

    VLOG(1) << "No Etag in response: " << req.base() << " " << create_file_name_ << " "
            << parts_.size();
  }

  return etag_found ? error_code{} : make_error_code(errc::io_error);
}

}  // namespace

io::Result<io::ReadonlyFile*> OpenS3ReadFile(std::string_view region, string_view path,
                                             const AWS& aws, std::unique_ptr<http::Client> client,
                                             const io::ReadonlyFile::Options& opts) {
  CHECK(opts.sequential && client);
  VLOG(1) << "OpenS3ReadFile: " << path;

  string read_obj_url{path};
  unique_ptr<S3ReadFile> fl(new S3ReadFile(aws, std::move(read_obj_url), std::move(client)));

  auto ec = fl->Open(region);
  if (ec)
    return nonstd::make_unexpected(ec);

  return fl.release();
}

io::Result<io::WriteFile*> OpenS3WriteFile(string_view region, string_view key_path,
                                           const AWS& aws,
                                           std::unique_ptr<http::Client> client) {
  return new S3WriteFile{key_path, region, aws, std::move(client)};
}

}  // namespace cloud
}  // namespace util
diff --git a/util/cloud/s3_file.h b/util/cloud/s3_file.h
new file mode 100644
index 00000000..290f0c5a
--- /dev/null
+++ b/util/cloud/s3_file.h
@@ -0,0 +1,24 @@
// Copyright 2023, Roman Gershman.  All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include "io/file.h"
#include "util/cloud/aws.h"
#include "util/http/http_client.h"

namespace util {
namespace cloud {

io::Result<io::ReadonlyFile*> OpenS3ReadFile(
    std::string_view region, std::string_view path, const AWS& aws,
    std::unique_ptr<http::Client> client,
    const io::ReadonlyFile::Options& opts = io::ReadonlyFile::Options{});

// Takes ownership over the client.
io::Result<io::WriteFile*> OpenS3WriteFile(std::string_view region, std::string_view key_path,
                                           const AWS& aws,
                                           std::unique_ptr<http::Client> client);

}  // namespace cloud
}  // namespace util
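
For orientation, here is a minimal sketch of how the pieces above are meant to fit together, mirroring what s3_old_demo.cc does. It is not part of the diff: `pp` (a started ProactorPool), the bucket name "my-bucket" and the key prefixes are placeholders, and, as in the demo, Init() and Connect() must run inside a proactor thread:

    AWS aws{"s3", "us-east-1"};
    pp->GetNextProactor()->Await([&] { CHECK(!aws.Init()); });

    cloud::S3Bucket bucket{aws, "my-bucket"};  // placeholder bucket name
    pp->GetNextProactor()->Await([&] {
      CHECK(!bucket.Connect(300));

      // ListAllObjects drives ListObjects page by page, feeding the returned
      // marker back until the listing is no longer truncated.
      error_code ec = bucket.ListAllObjects(
          "some/prefix", [](size_t sz, string_view key) { CONSOLE_INFO << key << " " << sz; });
      CHECK(!ec);

      // OpenReadFile hands the bucket's connected client to the returned file
      // and reconnects a fresh one for the bucket itself.
      io::Result<io::ReadonlyFile*> res = bucket.OpenReadFile("some/prefix/obj");
      if (res) {
        unique_ptr<io::ReadonlyFile> file{*res};
        uint8_t buf[1024];
        io::SizeOrError sz = file->Read(0, io::MutableBytes(buf, sizeof(buf)));
        CHECK(sz);
      }
    });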