Skip to content

Commit 6e022b7

Browse files
authored
fix(s3): request-retry doesn't support streaming bodies (#83)
1 parent 78954c7 commit 6e022b7

3 files changed

Lines changed: 81 additions & 139 deletions

File tree

Cargo.lock

Lines changed: 0 additions & 72 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ path = "src/bin/mangen.rs"
1818
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
1919

2020
[dependencies]
21-
anyhow = "1.0.79"
21+
anyhow = "1.0"
2222
clap = { version = "4.4.18", features = ["derive", "env"] }
2323
clap_complete = "4"
2424
log = "0.4.20"
@@ -37,8 +37,6 @@ reqwest = { version = "0.13", default-features = false, features = [
3737
"stream",
3838
"rustls",
3939
] }
40-
reqwest-middleware = "0.5"
41-
reqwest-retry = "0.9"
4240
time = { version = "0.3.36", features = ["serde-well-known"] }
4341
tokio = { version = "1.40.0", features = ["full"] }
4442
tokio-util = "0.7.11"

src/api.rs

Lines changed: 80 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ use async_trait::async_trait;
99
use futures::StreamExt;
1010
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
1111
use reqwest::{Body, Client, StatusCode};
12-
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
13-
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
1412
use serde::{Deserialize, Serialize};
1513
use serde_with::skip_serializing_none;
1614
use std::collections::HashMap;
@@ -85,7 +83,6 @@ pub struct RapiReqwestClient {
8583
base_url: String,
8684
api_key: String,
8785
client: Client,
88-
s3_client: ClientWithMiddleware,
8986
}
9087

9188
impl RapiReqwestClient {
@@ -104,11 +101,6 @@ impl RapiReqwestClient {
104101

105102
impl Default for RapiReqwestClient {
106103
fn default() -> Self {
107-
let reqwest_s3_client = Client::builder()
108-
.pool_idle_timeout(Some(Duration::from_secs(20)))
109-
.pool_max_idle_per_host(16)
110-
.build()
111-
.unwrap();
112104
Self {
113105
base_url: String::from("https:://cloud.marathonlabs.io/api"),
114106
api_key: "".into(),
@@ -117,11 +109,6 @@ impl Default for RapiReqwestClient {
117109
.pool_max_idle_per_host(16)
118110
.build()
119111
.unwrap(),
120-
s3_client: ClientBuilder::new(reqwest_s3_client)
121-
.with(RetryTransientMiddleware::new_with_policy(
122-
ExponentialBackoff::builder().build_with_max_retries(3),
123-
))
124-
.build(),
125112
}
126113
}
127114
}
@@ -186,7 +173,6 @@ impl RapiClient for RapiReqwestClient {
186173
s3_test_app_path = Some(
187174
upload_to_s3(
188175
&self.client,
189-
&self.s3_client,
190176
self.base_url.clone(),
191177
self.api_key.clone(),
192178
test_app.path.clone(),
@@ -201,7 +187,6 @@ impl RapiClient for RapiReqwestClient {
201187
s3_app_path = Some(
202188
upload_to_s3(
203189
&self.client,
204-
&self.s3_client,
205190
self.base_url.clone(),
206191
self.api_key.clone(),
207192
app.path.clone(),
@@ -217,7 +202,6 @@ impl RapiClient for RapiReqwestClient {
217202
for app_bundle in app_bundles {
218203
let s3_app_path = upload_to_s3(
219204
&self.client,
220-
&self.s3_client,
221205
self.base_url.clone(),
222206
self.api_key.clone(),
223207
app_bundle.application.path.clone(),
@@ -227,7 +211,6 @@ impl RapiClient for RapiReqwestClient {
227211

228212
let s3_test_app_path = upload_to_s3(
229213
&self.client,
230-
&self.s3_client,
231214
self.base_url.clone(),
232215
self.api_key.clone(),
233216
app_bundle.test_application.path.clone(),
@@ -249,7 +232,6 @@ impl RapiClient for RapiReqwestClient {
249232
for bundle in bundles {
250233
let s3_test_app_path = upload_to_s3(
251234
&self.client,
252-
&self.s3_client,
253235
self.base_url.clone(),
254236
self.api_key.clone(),
255237
bundle.test_application.path.clone(),
@@ -486,22 +468,35 @@ async fn api_error_adapter(response: reqwest::Response) -> Result<reqwest::Respo
486468
}
487469
}
488470

471+
fn retryable_io_error(error: &std::io::Error) -> bool {
472+
match error.kind() {
473+
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionAborted => true,
474+
_ => false,
475+
}
476+
}
477+
478+
fn get_source_error_type<T: std::error::Error + 'static>(
479+
err: &dyn std::error::Error,
480+
) -> Option<&T> {
481+
let mut source = err.source();
482+
483+
while let Some(err) = source {
484+
if let Some(err) = err.downcast_ref::<T>() {
485+
return Some(err);
486+
}
487+
488+
source = err.source();
489+
}
490+
None
491+
}
492+
489493
async fn upload_to_s3(
490494
client: &Client,
491-
s3_client: &ClientWithMiddleware,
492495
base_url_with_params: String,
493496
api_key: String,
494497
file_path: PathBuf,
495498
no_progress_bar: bool,
496499
) -> Result<String> {
497-
// Open file
498-
let file = File::open(&file_path)
499-
.await
500-
.map_err(|error| InputError::OpenFileFailure {
501-
path: file_path.clone(),
502-
error,
503-
})?;
504-
505500
// Extract filename from PathBuf
506501
let file_name = file_path
507502
.file_name()
@@ -527,53 +522,74 @@ async fn upload_to_s3(
527522
.map_err(|error| ApiError::DeserializationFailure { error })?;
528523

529524
// Progress stuff
530-
let file_total_size = file.metadata().await?.len();
531-
let mut file_reader = ReaderStream::new(file);
532525
let mut multi_progress: Option<MultiProgress> = if !no_progress_bar {
533526
Some(MultiProgress::new())
534527
} else {
535528
None
536529
};
537-
let file_progress_bar;
538-
let file_body;
539-
if !no_progress_bar {
540-
let sty = ProgressStyle::with_template(
530+
531+
let mut retries = 3;
532+
while retries > 0 {
533+
let file = File::open(&file_path)
534+
.await
535+
.map_err(|error| InputError::OpenFileFailure {
536+
path: file_path.clone(),
537+
error,
538+
})?;
539+
let file_total_size = file.metadata().await?.len();
540+
let mut file_reader = ReaderStream::new(file);
541+
let file_progress_bar;
542+
let file_body;
543+
if !no_progress_bar {
544+
let sty = ProgressStyle::with_template(
541545
"{spinner:.blue} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})"
542-
)
543-
.unwrap()
544-
.progress_chars("#>-");
545-
546-
let pb = ProgressBar::new(file_total_size);
547-
pb.enable_steady_tick(Duration::from_millis(80));
548-
file_progress_bar = multi_progress.as_mut().unwrap().add(pb);
549-
file_progress_bar.set_style(sty.clone());
550-
let mut file_progress = 0u64;
551-
let file_stream = async_stream::stream! {
552-
while let Some(chunk) = file_reader.next().await {
553-
let file_progress_bar = file_progress_bar.clone();
554-
if let Ok(chunk) = &chunk {
555-
let new = min(file_progress + (chunk.len() as u64), file_total_size);
556-
file_progress = new;
557-
file_progress_bar.set_position(new);
558-
if file_progress >= file_total_size {
559-
file_progress_bar.finish_and_clear();
546+
)
547+
.unwrap()
548+
.progress_chars("#>-");
549+
550+
let pb = ProgressBar::new(file_total_size);
551+
pb.enable_steady_tick(Duration::from_millis(80));
552+
file_progress_bar = multi_progress.as_mut().unwrap().add(pb);
553+
file_progress_bar.set_style(sty.clone());
554+
let mut file_progress = 0u64;
555+
let file_stream = async_stream::stream! {
556+
while let Some(chunk) = file_reader.next().await {
557+
let file_progress_bar = file_progress_bar.clone();
558+
if let Ok(chunk) = &chunk {
559+
let new = min(file_progress + (chunk.len() as u64), file_total_size);
560+
file_progress = new;
561+
file_progress_bar.set_position(new);
562+
if file_progress >= file_total_size {
563+
file_progress_bar.finish_and_clear();
564+
}
560565
}
566+
yield chunk;
567+
}
568+
};
569+
file_body = Body::wrap_stream(file_stream);
570+
} else {
571+
file_body = Body::wrap_stream(file_reader);
572+
}
573+
let s3_response = client
574+
.put(upload_url_response.url.clone())
575+
.header("Content-Length", file_total_size)
576+
.body(file_body)
577+
.send()
578+
.await;
579+
580+
//Check if it's a retriable IO error
581+
if let Err(err) = &s3_response {
582+
if let Some(io_error) = get_source_error_type::<std::io::Error>(err) {
583+
if retryable_io_error(io_error) {
584+
retries -= 1;
585+
continue;
561586
}
562-
yield chunk;
563587
}
564-
};
565-
file_body = Body::wrap_stream(file_stream);
566-
} else {
567-
file_body = Body::wrap_stream(file_reader);
568-
}
588+
}
569589

570-
let s3_response = s3_client
571-
.put(upload_url_response.url.clone())
572-
.header("Content-Length", file_total_size)
573-
.body(file_body)
574-
.send()
575-
.await?;
576-
api_error_adapter(s3_response).await?;
590+
api_error_adapter(s3_response?).await?;
591+
break;
592+
}
577593

578594
Ok(upload_url_response.file_path.clone())
579595
}

0 commit comments

Comments
 (0)