mirror of
https://github.com/ducaale/xh.git
synced 2025-05-05 15:32:50 +00:00
Delay initialization of zstd decoder
Fixes a panic for `xh head https://httpbin.dev/zstd`. `ZstdDecoder::new()` returns a `Result`. We used to panic on this, but it needs to be a `Read` error instead, so we can suppress the error for an empty input the way we do for other decoders. Our existing approach couldn't handle this, so I ended up refactoring the system. I think it's cleaner now, though still weird. We now also preserve the original decoder error instead of `.to_string()`ing it, or strip it completely if there was an I/O error. That should improve the error reporting.
This commit is contained in:
parent
9f98ad634a
commit
b4e3fb2012
349
src/decoder.rs
349
src/decoder.rs
@ -1,4 +1,6 @@
|
||||
use std::cell::Cell;
|
||||
use std::io::{self, Read};
|
||||
use std::rc::Rc;
|
||||
use std::str::FromStr;
|
||||
|
||||
use brotli::Decompressor as BrotliDecoder;
|
||||
@ -6,7 +8,7 @@ use flate2::read::{GzDecoder, ZlibDecoder};
|
||||
use reqwest::header::{HeaderMap, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
|
||||
use ruzstd::{FrameDecoder, StreamingDecoder as ZstdDecoder};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum CompressionType {
|
||||
Gzip,
|
||||
Deflate,
|
||||
@ -52,89 +54,111 @@ pub fn get_compression_type(headers: &HeaderMap) -> Option<CompressionType> {
|
||||
compression_type
|
||||
}
|
||||
|
||||
struct InnerReader<R: Read> {
|
||||
reader: R,
|
||||
has_read_data: bool,
|
||||
has_errored: bool,
|
||||
/// A wrapper that checks whether an error is an I/O error or a decoding error.
|
||||
///
|
||||
/// The main purpose of this is to suppress decoding errors that happen because
|
||||
/// of an empty input. This is behavior we inherited from HTTPie.
|
||||
///
|
||||
/// It's load-bearing in the case of HEAD requests, where responses don't have a
|
||||
/// body but may declare a Content-Encoding. We also treat other requests like this,
|
||||
/// perhaps unwisely.
|
||||
///
|
||||
/// As a side benefit we make I/O errors more focused by stripping decoding errors.
|
||||
///
|
||||
/// The reader is structured like this:
|
||||
///
|
||||
/// OuterReader ───────┐
|
||||
/// compression codec ├── [Status]
|
||||
/// [InnerReader] ──────┘
|
||||
/// underlying I/O
|
||||
///
|
||||
/// The shared Status object is used to communicate.
|
||||
struct OuterReader<'a> {
|
||||
decoder: Box<dyn Read + 'a>,
|
||||
status: Option<Rc<Status>>,
|
||||
}
|
||||
|
||||
impl<R: Read> InnerReader<R> {
|
||||
fn new(reader: R) -> Self {
|
||||
InnerReader {
|
||||
reader,
|
||||
has_read_data: false,
|
||||
has_errored: false,
|
||||
/// State shared between an [OuterReader] and its [InnerReader].
struct Status {
    /// Set once the underlying reader has returned a non-empty read.
    has_read_data: Cell<bool>,
    /// The real error from the latest failed read of the underlying reader,
    /// if there was one.
    read_error: Cell<Option<io::Error>>,
    /// Message used to decorate genuine decoding errors.
    error_msg: &'static str,
}
|
||||
|
||||
impl Read for OuterReader<'_> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self.decoder.read(buf) {
|
||||
Ok(n) => Ok(n),
|
||||
Err(err) => {
|
||||
let Some(ref status) = self.status else {
|
||||
// No decoder, pass on as is
|
||||
return Err(err);
|
||||
};
|
||||
match status.read_error.take() {
|
||||
// If an I/O error happened, return that.
|
||||
Some(read_error) => Err(read_error),
|
||||
// If the input was empty, ignore the decoder error.
|
||||
None if !status.has_read_data.get() => Ok(0),
|
||||
// Otherwise, decorate the decoder error with a message.
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
DecodeError {
|
||||
msg: status.error_msg,
|
||||
err,
|
||||
},
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct InnerReader<R: Read> {
|
||||
reader: R,
|
||||
status: Rc<Status>,
|
||||
}
|
||||
|
||||
impl<R: Read> Read for InnerReader<R> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.has_errored = false;
|
||||
self.status.read_error.set(None);
|
||||
match self.reader.read(buf) {
|
||||
Ok(0) => Ok(0),
|
||||
Ok(len) => {
|
||||
self.has_read_data = true;
|
||||
self.status.has_read_data.set(true);
|
||||
Ok(len)
|
||||
}
|
||||
Err(e) => {
|
||||
self.has_errored = true;
|
||||
Err(e)
|
||||
Err(err) => {
|
||||
// Store the real error and return a placeholder.
|
||||
// The placeholder is intercepted and replaced by the real error
|
||||
// before leaving this module.
|
||||
// We store the whole error instead of setting a flag because of zstd:
|
||||
// - ZstdDecoder::new() fails with a custom error type and it's hard
|
||||
// to extract the underlying io::Error
|
||||
// - ZstdDecoder::read() (unlike the other decoders) wraps custom errors
|
||||
// around the underlying io::Error
|
||||
let msg = err.to_string();
|
||||
let kind = err.kind();
|
||||
self.status.read_error.set(Some(err));
|
||||
Err(io::Error::new(kind, msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
enum Decoder<R: Read> {
|
||||
PlainText(InnerReader<R>),
|
||||
Gzip(GzDecoder<InnerReader<R>>),
|
||||
Deflate(ZlibDecoder<InnerReader<R>>),
|
||||
Brotli(BrotliDecoder<InnerReader<R>>),
|
||||
Zstd(ZstdDecoder<InnerReader<R>, FrameDecoder>),
|
||||
/// A genuine decoding failure: a fixed, codec-specific message together with
/// the error the decoder originally produced.
#[derive(Debug)]
struct DecodeError {
    /// Human-readable description naming the codec that failed.
    msg: &'static str,
    /// The decoder's own error, kept intact rather than stringified.
    err: io::Error,
}
|
||||
|
||||
impl<R: Read> Read for Decoder<R> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Decoder::PlainText(decoder) => decoder.read(buf),
|
||||
Decoder::Gzip(decoder) => match decoder.read(buf) {
|
||||
Ok(n) => Ok(n),
|
||||
Err(e) if decoder.get_ref().has_errored => Err(e),
|
||||
Err(_) if !decoder.get_ref().has_read_data => Ok(0),
|
||||
Err(e) => Err(io::Error::new(
|
||||
e.kind(),
|
||||
format!("error decoding gzip response body: {}", e),
|
||||
)),
|
||||
},
|
||||
Decoder::Deflate(decoder) => match decoder.read(buf) {
|
||||
Ok(n) => Ok(n),
|
||||
Err(e) if decoder.get_ref().has_errored => Err(e),
|
||||
Err(_) if !decoder.get_ref().has_read_data => Ok(0),
|
||||
Err(e) => Err(io::Error::new(
|
||||
e.kind(),
|
||||
format!("error decoding deflate response body: {}", e),
|
||||
)),
|
||||
},
|
||||
Decoder::Brotli(decoder) => match decoder.read(buf) {
|
||||
Ok(n) => Ok(n),
|
||||
Err(e) if decoder.get_ref().has_errored => Err(e),
|
||||
Err(_) if !decoder.get_ref().has_read_data => Ok(0),
|
||||
Err(e) => Err(io::Error::new(
|
||||
e.kind(),
|
||||
format!("error decoding brotli response body: {}", e),
|
||||
)),
|
||||
},
|
||||
Decoder::Zstd(decoder) => match decoder.read(buf) {
|
||||
Ok(n) => Ok(n),
|
||||
Err(e) if decoder.get_ref().has_errored => Err(e),
|
||||
Err(_) if !decoder.get_ref().has_read_data => Ok(0),
|
||||
Err(e) => Err(io::Error::new(
|
||||
e.kind(),
|
||||
format!("error decoding zstd response body: {}", e),
|
||||
)),
|
||||
},
|
||||
}
|
||||
impl std::fmt::Display for DecodeError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.msg)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for DecodeError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
Some(&self.err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -142,18 +166,70 @@ pub fn decompress(
|
||||
reader: &mut impl Read,
|
||||
compression_type: Option<CompressionType>,
|
||||
) -> impl Read + '_ {
|
||||
let reader = InnerReader::new(reader);
|
||||
match compression_type {
|
||||
Some(CompressionType::Gzip) => Decoder::Gzip(GzDecoder::new(reader)),
|
||||
Some(CompressionType::Deflate) => Decoder::Deflate(ZlibDecoder::new(reader)),
|
||||
Some(CompressionType::Brotli) => Decoder::Brotli(BrotliDecoder::new(reader, 4096)),
|
||||
Some(CompressionType::Zstd) => Decoder::Zstd(ZstdDecoder::new(reader).unwrap()),
|
||||
None => Decoder::PlainText(reader),
|
||||
let Some(compression_type) = compression_type else {
|
||||
return OuterReader {
|
||||
decoder: Box::new(reader),
|
||||
status: None,
|
||||
};
|
||||
};
|
||||
|
||||
let status = Rc::new(Status {
|
||||
has_read_data: Cell::new(false),
|
||||
read_error: Cell::new(None),
|
||||
error_msg: match compression_type {
|
||||
CompressionType::Gzip => "error decoding gzip response body",
|
||||
CompressionType::Deflate => "error decoding deflate response body",
|
||||
CompressionType::Brotli => "error decoding brotli response body",
|
||||
CompressionType::Zstd => "error decoding zstd response body",
|
||||
},
|
||||
});
|
||||
let reader = InnerReader {
|
||||
reader,
|
||||
status: Rc::clone(&status),
|
||||
};
|
||||
OuterReader {
|
||||
decoder: match compression_type {
|
||||
CompressionType::Gzip => Box::new(GzDecoder::new(reader)),
|
||||
CompressionType::Deflate => Box::new(ZlibDecoder::new(reader)),
|
||||
CompressionType::Brotli => Box::new(BrotliDecoder::new(reader, 4096)),
|
||||
CompressionType::Zstd => Box::new(LazyZstdDecoder::Uninit(Some(reader))),
|
||||
},
|
||||
status: Some(status),
|
||||
}
|
||||
}
|
||||
|
||||
/// [ZstdDecoder] reads from its input during construction.
|
||||
///
|
||||
/// We need to delay construction until [Read] so read errors stay read errors.
|
||||
enum LazyZstdDecoder<R: Read> {
|
||||
Uninit(Option<R>),
|
||||
Init(ZstdDecoder<R, FrameDecoder>),
|
||||
}
|
||||
|
||||
impl<R: Read> Read for LazyZstdDecoder<R> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
LazyZstdDecoder::Uninit(reader) => match reader.take() {
|
||||
Some(reader) => match ZstdDecoder::new(reader) {
|
||||
Ok(decoder) => {
|
||||
*self = LazyZstdDecoder::Init(decoder);
|
||||
self.read(buf)
|
||||
}
|
||||
Err(err) => Err(io::Error::other(err)),
|
||||
},
|
||||
// We seem to get here in --stream mode because another layer tries
|
||||
// to read again after Ok(0).
|
||||
None => Err(io::Error::other("failed to construct ZstdDecoder")),
|
||||
},
|
||||
LazyZstdDecoder::Init(streaming_decoder) => streaming_decoder.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::error::Error;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
@ -167,7 +243,7 @@ mod tests {
|
||||
Err(e) => {
|
||||
assert!(e
|
||||
.to_string()
|
||||
.starts_with("error decoding gzip response body:"))
|
||||
.starts_with("error decoding gzip response body"))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -195,23 +271,128 @@ mod tests {
|
||||
#[test]
|
||||
fn interrupts_are_handled_gracefully() {
|
||||
struct InterruptedReader {
|
||||
init: bool,
|
||||
step: u8,
|
||||
}
|
||||
impl Read for InterruptedReader {
|
||||
fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
|
||||
if !self.init {
|
||||
self.init = true;
|
||||
Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"))
|
||||
} else {
|
||||
Ok(0)
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.step += 1;
|
||||
match self.step {
|
||||
1 => Read::read(&mut b"abc".as_slice(), buf),
|
||||
2 => Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted")),
|
||||
3 => Read::read(&mut b"def".as_slice(), buf),
|
||||
_ => Ok(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut base_reader = InterruptedReader { init: false };
|
||||
let mut reader = decompress(&mut base_reader, Some(CompressionType::Gzip));
|
||||
let mut buffer = Vec::new();
|
||||
reader.read_to_end(&mut buffer).unwrap();
|
||||
assert_eq!(buffer, b"");
|
||||
for compression_type in [
|
||||
None,
|
||||
Some(CompressionType::Brotli),
|
||||
Some(CompressionType::Deflate),
|
||||
Some(CompressionType::Gzip),
|
||||
Some(CompressionType::Zstd),
|
||||
] {
|
||||
let mut base_reader = InterruptedReader { step: 0 };
|
||||
let mut reader = decompress(&mut base_reader, compression_type);
|
||||
let mut buffer = Vec::with_capacity(16);
|
||||
let res = reader.read_to_end(&mut buffer);
|
||||
if compression_type.is_none() {
|
||||
res.unwrap();
|
||||
assert_eq!(buffer, b"abcdef");
|
||||
} else {
|
||||
res.unwrap_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An empty body must decode to nothing without an error, for every codec
/// and for the uncompressed case.
#[test]
fn empty_inputs_do_not_cause_errors() {
    let every_mode = [
        None,
        Some(CompressionType::Brotli),
        Some(CompressionType::Deflate),
        Some(CompressionType::Gzip),
        Some(CompressionType::Zstd),
    ];

    for mode in every_mode {
        let mut empty: &[u8] = b"";
        let mut reader = decompress(&mut empty, mode);
        let mut output = Vec::new();

        reader.read_to_end(&mut output).unwrap();
        assert_eq!(output, b"");

        // Must accept repeated read attempts after EOF (this happens with --stream)
        for _ in 0..10 {
            reader.read_to_end(&mut output).unwrap();
            assert_eq!(output, b"");
        }
    }
}
|
||||
|
||||
/// An error raised by the underlying reader must come out with its original
/// kind and payload, whichever codec sits on top of it.
#[test]
fn read_errors_keep_their_context() {
    /// Marker payload that can be recognized again by downcasting.
    #[derive(Debug)]
    struct SpecialErr;

    impl std::fmt::Display for SpecialErr {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{:?}", self)
        }
    }

    impl std::error::Error for SpecialErr {}

    /// A reader whose every read fails with a `WouldBlock` error carrying
    /// a `SpecialErr`.
    struct SadReader;

    impl Read for SadReader {
        fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
            Err(io::Error::new(io::ErrorKind::WouldBlock, SpecialErr))
        }
    }

    let every_mode = [
        None,
        Some(CompressionType::Brotli),
        Some(CompressionType::Deflate),
        Some(CompressionType::Gzip),
        Some(CompressionType::Zstd),
    ];

    for mode in every_mode {
        let mut failing = SadReader;
        let mut reader = decompress(&mut failing, mode);
        let mut sink = Vec::new();

        let err = reader.read_to_end(&mut sink).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WouldBlock);

        let payload = err.get_ref().unwrap();
        payload.downcast_ref::<SpecialErr>().unwrap();
    }
}
|
||||
|
||||
/// Invalid compressed data must surface as an `InvalidData` error wrapping a
/// `DecodeError` whose source is the decoder's own, unmodified error.
#[test]
fn true_decode_errors_are_preserved() {
    let codecs = [
        CompressionType::Brotli,
        CompressionType::Deflate,
        CompressionType::Gzip,
        CompressionType::Zstd,
    ];

    for codec in codecs {
        // All four decoders make a different choice here...
        // Still the easiest way to check that we're preserving the error
        let expected_kind = match codec {
            CompressionType::Gzip => io::ErrorKind::UnexpectedEof,
            CompressionType::Deflate => io::ErrorKind::InvalidInput,
            CompressionType::Brotli => io::ErrorKind::InvalidData,
            CompressionType::Zstd => io::ErrorKind::Other,
        };

        let mut garbage: &[u8] = b"bad";
        let mut reader = decompress(&mut garbage, Some(codec));
        let mut sink = Vec::new();
        let outer_err = reader.read_to_end(&mut sink).unwrap_err();
        assert_eq!(outer_err.kind(), io::ErrorKind::InvalidData);

        let decode_err = outer_err
            .get_ref()
            .unwrap()
            .downcast_ref::<DecodeError>()
            .unwrap();
        let cause = decode_err.source().unwrap();
        let cause = cause.downcast_ref::<io::Error>().unwrap();
        assert_eq!(cause.kind(), expected_kind);
    }
}
|
||||
}
|
||||
|
53
tests/cli.rs
53
tests/cli.rs
@ -3570,6 +3570,59 @@ fn empty_response_with_content_encoding_and_content_length() {
|
||||
"#});
|
||||
}
|
||||
|
||||
/// Regression test: this used to crash because ZstdDecoder::new() is fallible
///
/// The server declares a zstd body of 100 bytes but sends nothing; only the
/// response headers are expected on stdout.
#[test]
fn empty_zstd_response_with_content_encoding_and_content_length() {
    let server = server::http(|_request| async move {
        let response = hyper::Response::builder()
            .header("date", "N/A")
            .header("content-encoding", "zstd")
            .header("content-length", "100")
            .body("".into());
        response.unwrap()
    });

    let expected_stdout = indoc! {r#"
        HTTP/1.1 200 OK
        Content-Encoding: zstd
        Content-Length: 100
        Date: N/A


    "#};

    get_command()
        .arg("head")
        .arg(server.base_url())
        .assert()
        .stdout(expected_stdout);
}
|
||||
|
||||
/// After an initial fix this scenario still crashed
///
/// Same empty zstd response as the non-streaming regression test, but run
/// with `--stream`.
#[test]
fn streaming_empty_zstd_response_with_content_encoding_and_content_length() {
    let server = server::http(|_request| async move {
        let response = hyper::Response::builder()
            .header("date", "N/A")
            .header("content-encoding", "zstd")
            .header("content-length", "100")
            .body("".into());
        response.unwrap()
    });

    let expected_stdout = indoc! {r#"
        HTTP/1.1 200 OK
        Content-Encoding: zstd
        Content-Length: 100
        Date: N/A


    "#};

    get_command()
        .arg("--stream")
        .arg("head")
        .arg(server.base_url())
        .assert()
        .stdout(expected_stdout);
}
|
||||
|
||||
#[test]
|
||||
fn response_meta() {
|
||||
let server = server::http(|_req| async move {
|
||||
|
Loading…
x
Reference in New Issue
Block a user