Merge pull request #411 from blyxxyz/zstd-lazy-init

Fix crash on empty zstd response body
This commit is contained in:
Mohamed Daahir 2025-03-16 22:24:52 +00:00 committed by GitHub
commit 24afadb354
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 340 additions and 90 deletions

View File

@ -1,4 +1,6 @@
use std::cell::Cell;
use std::io::{self, Read};
use std::rc::Rc;
use std::str::FromStr;
use brotli::Decompressor as BrotliDecoder;
@ -6,7 +8,7 @@ use flate2::read::{GzDecoder, ZlibDecoder};
use reqwest::header::{HeaderMap, CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
use ruzstd::{FrameDecoder, StreamingDecoder as ZstdDecoder};
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub enum CompressionType {
Gzip,
Deflate,
@ -18,7 +20,11 @@ impl FromStr for CompressionType {
type Err = anyhow::Error;
fn from_str(value: &str) -> anyhow::Result<CompressionType> {
match value {
"gzip" => Ok(CompressionType::Gzip),
// RFC 2616 section 3.5:
// For compatibility with previous implementations of HTTP,
// applications SHOULD consider "x-gzip" and "x-compress" to be
// equivalent to "gzip" and "compress" respectively.
"gzip" | "x-gzip" => Ok(CompressionType::Gzip),
"deflate" => Ok(CompressionType::Deflate),
"br" => Ok(CompressionType::Brotli),
"zstd" => Ok(CompressionType::Zstd),
@ -52,89 +58,114 @@ pub fn get_compression_type(headers: &HeaderMap) -> Option<CompressionType> {
compression_type
}
struct InnerReader<R: Read> {
reader: R,
has_read_data: bool,
has_errored: bool,
/// A wrapper that checks whether an error is an I/O error or a decoding error.
///
/// The main purpose of this is to suppress decoding errors that happen because
/// of an empty input. This is behavior we inherited from HTTPie.
///
/// It's load-bearing in the case of HEAD requests, where responses don't have a
/// body but may declare a Content-Encoding.
///
/// We also treat other empty response bodies like this, regardless of the request
/// method. This matches all the user agents I tried (reqwest, requests/HTTPie, curl,
/// wget, Firefox, Chromium) but I don't know if it's prescribed by any RFC.
///
/// As a side benefit we make I/O errors more focused by stripping decoding errors.
///
/// The reader is structured like this:
///
/// OuterReader ───────┐
/// compression codec ├── [Status]
/// [InnerReader] ──────┘
/// underlying I/O
///
/// The shared Status object is used to communicate.
struct OuterReader<'a> {
decoder: Box<dyn Read + 'a>,
status: Option<Rc<Status>>,
}
impl<R: Read> InnerReader<R> {
fn new(reader: R) -> Self {
InnerReader {
reader,
has_read_data: false,
has_errored: false,
struct Status {
has_read_data: Cell<bool>,
read_error: Cell<Option<io::Error>>,
error_msg: &'static str,
}
impl Read for OuterReader<'_> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.decoder.read(buf) {
Ok(n) => Ok(n),
Err(err) => {
let Some(ref status) = self.status else {
// No decoder, pass on as is
return Err(err);
};
match status.read_error.take() {
// If an I/O error happened, return that.
Some(read_error) => Err(read_error),
// If the input was empty, ignore the decoder error.
None if !status.has_read_data.get() => Ok(0),
// Otherwise, decorate the decoder error with a message.
None => Err(io::Error::new(
io::ErrorKind::InvalidData,
DecodeError {
msg: status.error_msg,
err,
},
)),
}
}
}
}
}
struct InnerReader<R: Read> {
reader: R,
status: Rc<Status>,
}
impl<R: Read> Read for InnerReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.has_errored = false;
self.status.read_error.set(None);
match self.reader.read(buf) {
Ok(0) => Ok(0),
Ok(len) => {
self.has_read_data = true;
self.status.has_read_data.set(true);
Ok(len)
}
Err(e) => {
self.has_errored = true;
Err(e)
Err(err) => {
// Store the real error and return a placeholder.
// The placeholder is intercepted and replaced by the real error
// before leaving this module.
// We store the whole error instead of setting a flag because of zstd:
// - ZstdDecoder::new() fails with a custom error type and it's hard
// to extract the underlying io::Error
// - ZstdDecoder::read() (unlike the other decoders) wraps custom errors
// around the underlying io::Error
let msg = err.to_string();
let kind = err.kind();
self.status.read_error.set(Some(err));
Err(io::Error::new(kind, msg))
}
}
}
}
#[allow(clippy::large_enum_variant)]
enum Decoder<R: Read> {
PlainText(InnerReader<R>),
Gzip(GzDecoder<InnerReader<R>>),
Deflate(ZlibDecoder<InnerReader<R>>),
Brotli(BrotliDecoder<InnerReader<R>>),
Zstd(ZstdDecoder<InnerReader<R>, FrameDecoder>),
#[derive(Debug)]
struct DecodeError {
msg: &'static str,
err: io::Error,
}
impl<R: Read> Read for Decoder<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Decoder::PlainText(decoder) => decoder.read(buf),
Decoder::Gzip(decoder) => match decoder.read(buf) {
Ok(n) => Ok(n),
Err(e) if decoder.get_ref().has_errored => Err(e),
Err(_) if !decoder.get_ref().has_read_data => Ok(0),
Err(e) => Err(io::Error::new(
e.kind(),
format!("error decoding gzip response body: {}", e),
)),
},
Decoder::Deflate(decoder) => match decoder.read(buf) {
Ok(n) => Ok(n),
Err(e) if decoder.get_ref().has_errored => Err(e),
Err(_) if !decoder.get_ref().has_read_data => Ok(0),
Err(e) => Err(io::Error::new(
e.kind(),
format!("error decoding deflate response body: {}", e),
)),
},
Decoder::Brotli(decoder) => match decoder.read(buf) {
Ok(n) => Ok(n),
Err(e) if decoder.get_ref().has_errored => Err(e),
Err(_) if !decoder.get_ref().has_read_data => Ok(0),
Err(e) => Err(io::Error::new(
e.kind(),
format!("error decoding brotli response body: {}", e),
)),
},
Decoder::Zstd(decoder) => match decoder.read(buf) {
Ok(n) => Ok(n),
Err(e) if decoder.get_ref().has_errored => Err(e),
Err(_) if !decoder.get_ref().has_read_data => Ok(0),
Err(e) => Err(io::Error::new(
e.kind(),
format!("error decoding zstd response body: {}", e),
)),
},
}
impl std::fmt::Display for DecodeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.msg)
}
}
impl std::error::Error for DecodeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&self.err)
}
}
@ -142,18 +173,71 @@ pub fn decompress(
reader: &mut impl Read,
compression_type: Option<CompressionType>,
) -> impl Read + '_ {
let reader = InnerReader::new(reader);
match compression_type {
Some(CompressionType::Gzip) => Decoder::Gzip(GzDecoder::new(reader)),
Some(CompressionType::Deflate) => Decoder::Deflate(ZlibDecoder::new(reader)),
Some(CompressionType::Brotli) => Decoder::Brotli(BrotliDecoder::new(reader, 4096)),
Some(CompressionType::Zstd) => Decoder::Zstd(ZstdDecoder::new(reader).unwrap()),
None => Decoder::PlainText(reader),
let Some(compression_type) = compression_type else {
return OuterReader {
decoder: Box::new(reader),
status: None,
};
};
let status = Rc::new(Status {
has_read_data: Cell::new(false),
read_error: Cell::new(None),
error_msg: match compression_type {
CompressionType::Gzip => "error decoding gzip response body",
CompressionType::Deflate => "error decoding deflate response body",
CompressionType::Brotli => "error decoding brotli response body",
CompressionType::Zstd => "error decoding zstd response body",
},
});
let reader = InnerReader {
reader,
status: Rc::clone(&status),
};
OuterReader {
decoder: match compression_type {
CompressionType::Gzip => Box::new(GzDecoder::new(reader)),
CompressionType::Deflate => Box::new(ZlibDecoder::new(reader)),
// 32K is the default buffer size for gzip and deflate
CompressionType::Brotli => Box::new(BrotliDecoder::new(reader, 32 * 1024)),
CompressionType::Zstd => Box::new(LazyZstdDecoder::Uninit(Some(reader))),
},
status: Some(status),
}
}
/// [ZstdDecoder] reads from its input during construction.
///
/// We need to delay construction until [Read] so read errors stay read errors.
enum LazyZstdDecoder<R: Read> {
Uninit(Option<R>),
Init(ZstdDecoder<R, FrameDecoder>),
}
impl<R: Read> Read for LazyZstdDecoder<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
LazyZstdDecoder::Uninit(reader) => match reader.take() {
Some(reader) => match ZstdDecoder::new(reader) {
Ok(decoder) => {
*self = LazyZstdDecoder::Init(decoder);
self.read(buf)
}
Err(err) => Err(io::Error::other(err)),
},
// We seem to get here in --stream mode because another layer tries
// to read again after Ok(0).
None => Err(io::Error::other("failed to construct ZstdDecoder")),
},
LazyZstdDecoder::Init(streaming_decoder) => streaming_decoder.read(buf),
}
}
}
#[cfg(test)]
mod tests {
use std::error::Error;
use super::*;
#[test]
@ -167,7 +251,7 @@ mod tests {
Err(e) => {
assert!(e
.to_string()
.starts_with("error decoding gzip response body:"))
.starts_with("error decoding gzip response body"))
}
}
}
@ -195,23 +279,128 @@ mod tests {
#[test]
fn interrupts_are_handled_gracefully() {
struct InterruptedReader {
init: bool,
step: u8,
}
impl Read for InterruptedReader {
fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
if !self.init {
self.init = true;
Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"))
} else {
Ok(0)
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.step += 1;
match self.step {
1 => Read::read(&mut b"abc".as_slice(), buf),
2 => Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted")),
3 => Read::read(&mut b"def".as_slice(), buf),
_ => Ok(0),
}
}
}
let mut base_reader = InterruptedReader { init: false };
let mut reader = decompress(&mut base_reader, Some(CompressionType::Gzip));
let mut buffer = Vec::new();
reader.read_to_end(&mut buffer).unwrap();
assert_eq!(buffer, b"");
for compression_type in [
None,
Some(CompressionType::Brotli),
Some(CompressionType::Deflate),
Some(CompressionType::Gzip),
Some(CompressionType::Zstd),
] {
let mut base_reader = InterruptedReader { step: 0 };
let mut reader = decompress(&mut base_reader, compression_type);
let mut buffer = Vec::with_capacity(16);
let res = reader.read_to_end(&mut buffer);
if compression_type.is_none() {
res.unwrap();
assert_eq!(buffer, b"abcdef");
} else {
res.unwrap_err();
}
}
}
#[test]
fn empty_inputs_do_not_cause_errors() {
for compression_type in [
None,
Some(CompressionType::Brotli),
Some(CompressionType::Deflate),
Some(CompressionType::Gzip),
Some(CompressionType::Zstd),
] {
let mut input: &[u8] = b"";
let mut reader = decompress(&mut input, compression_type);
let mut buf = Vec::new();
reader.read_to_end(&mut buf).unwrap();
assert_eq!(buf, b"");
// Must accept repeated read attempts after EOF (this happens with --stream)
for _ in 0..10 {
reader.read_to_end(&mut buf).unwrap();
assert_eq!(buf, b"");
}
}
}
#[test]
fn read_errors_keep_their_context() {
#[derive(Debug)]
struct SpecialErr;
impl std::fmt::Display for SpecialErr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl std::error::Error for SpecialErr {}
struct SadReader;
impl Read for SadReader {
fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::WouldBlock, SpecialErr))
}
}
for compression_type in [
None,
Some(CompressionType::Brotli),
Some(CompressionType::Deflate),
Some(CompressionType::Gzip),
Some(CompressionType::Zstd),
] {
let mut input = SadReader;
let mut reader = decompress(&mut input, compression_type);
let mut buf = Vec::new();
let err = reader.read_to_end(&mut buf).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
err.get_ref().unwrap().downcast_ref::<SpecialErr>().unwrap();
}
}
#[test]
fn true_decode_errors_are_preserved() {
for compression_type in [
CompressionType::Brotli,
CompressionType::Deflate,
CompressionType::Gzip,
CompressionType::Zstd,
] {
let mut input: &[u8] = b"bad";
let mut reader = decompress(&mut input, Some(compression_type));
let mut buf = Vec::new();
let err = reader.read_to_end(&mut buf).unwrap_err();
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
let decode_err = err
.get_ref()
.unwrap()
.downcast_ref::<DecodeError>()
.unwrap();
let real_err = decode_err.source().unwrap();
let real_err = real_err.downcast_ref::<io::Error>().unwrap();
// All four decoders make a different choice here...
// Still the easiest way to check that we're preserving the error
let expected_kind = match compression_type {
CompressionType::Gzip => io::ErrorKind::UnexpectedEof,
CompressionType::Deflate => io::ErrorKind::InvalidInput,
CompressionType::Brotli => io::ErrorKind::InvalidData,
CompressionType::Zstd => io::ErrorKind::Other,
};
assert_eq!(real_err.kind(), expected_kind);
}
}
}

View File

@ -44,6 +44,17 @@ struct BinaryGuard<'a, T: Read> {
checked: bool,
}
#[derive(Debug)]
struct FoundBinaryData;
impl std::fmt::Display for FoundBinaryData {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("binary data not shown in terminal")
}
}
impl std::error::Error for FoundBinaryData {}
impl<'a, T: Read> BinaryGuard<'a, T> {
fn new(reader: &'a mut T, checked: bool) -> Self {
Self {
@ -78,10 +89,7 @@ impl<'a, T: Read> BinaryGuard<'a, T> {
Err(e) => return Err(e),
};
if self.checked && buf.contains(&b'\0') {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Found binary data",
));
return Err(io::Error::new(io::ErrorKind::InvalidData, FoundBinaryData));
} else if buf.is_empty() {
if self.buffer.is_empty() {
return Ok(None);
@ -454,7 +462,7 @@ impl Printer {
Ok(_) => {
self.buffer.print("\n")?;
}
Err(err) if err.kind() == io::ErrorKind::InvalidData => {
Err(err) if err.get_ref().is_some_and(|err| err.is::<FoundBinaryData>()) => {
self.buffer.print(BINARY_SUPPRESSOR)?;
}
Err(err) => return Err(err.into()),

View File

@ -3570,6 +3570,59 @@ fn empty_response_with_content_encoding_and_content_length() {
"#});
}
/// Regression test: this used to crash because ZstdDecoder::new() is fallible
#[test]
fn empty_zstd_response_with_content_encoding_and_content_length() {
let server = server::http(|_req| async move {
hyper::Response::builder()
.header("date", "N/A")
.header("content-encoding", "zstd")
.header("content-length", "100")
.body("".into())
.unwrap()
});
get_command()
.arg("head")
.arg(server.base_url())
.assert()
.stdout(indoc! {r#"
HTTP/1.1 200 OK
Content-Encoding: zstd
Content-Length: 100
Date: N/A
"#});
}
/// After an initial fix this scenario still crashed
#[test]
fn streaming_empty_zstd_response_with_content_encoding_and_content_length() {
let server = server::http(|_req| async move {
hyper::Response::builder()
.header("date", "N/A")
.header("content-encoding", "zstd")
.header("content-length", "100")
.body("".into())
.unwrap()
});
get_command()
.arg("--stream")
.arg("head")
.arg(server.base_url())
.assert()
.stdout(indoc! {r#"
HTTP/1.1 200 OK
Content-Encoding: zstd
Content-Length: 100
Date: N/A
"#});
}
#[test]
fn response_meta() {
let server = server::http(|_req| async move {