#![allow(clippy::module_name_repetitions)]
use std::ops::RangeInclusive;
use bytes::{Buf, Bytes, BytesMut};
use tokio_util::codec;
use crate::{packet_line::PktLine, Error};
/// Inclusive bounds on a data pkt-line's length prefix.
///
/// The prefix counts its own four bytes, so `4` is a packet with an empty
/// payload. The upper bound matches Git's `LARGE_PACKET_MAX`.
const ALLOWED_PACKET_LENGTH: RangeInclusive<usize> = 4..=65_520;
/// Stateless [`codec::Encoder`] that writes [`PktLine`]s into an output buffer.
#[derive(Debug, Default, Clone, Copy)]
pub struct Encoder;

impl codec::Encoder<PktLine<'_>> for Encoder {
    type Error = Error;

    /// Appends the wire encoding of `item` to `dst` by delegating to
    /// `PktLine::encode_to`.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) any error returned by `PktLine::encode_to`.
    fn encode(&mut self, item: PktLine<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
        // Whatever `encode_to` yields on success is not needed by the codec.
        item.encode_to(dst)?;
        Ok(())
    }
}
/// A request assembled by [`GitCodec`] from the pkt-lines that precede a
/// flush packet.
#[derive(Default, Debug, Eq, PartialEq)]
pub struct GitCommand {
    /// Payload of the first non-empty data packet, with at most one trailing
    /// `\n` removed. Empty if no such packet arrived before the flush.
    pub command: Bytes,
    /// Payloads of every data packet after `command`, in arrival order, each
    /// with at most one trailing `\n` removed.
    pub metadata: Vec<Bytes>,
}
/// Stateful [`codec::Decoder`] that groups incoming pkt-lines into
/// [`GitCommand`]s.
///
/// Data packets are buffered across `decode` calls until a flush packet
/// (`0000`) arrives, at which point the accumulated command is emitted.
#[derive(Debug, Default)]
pub struct GitCodec {
    /// The command being assembled from packets received since the last flush.
    command: GitCommand,
}
impl codec::Decoder for GitCodec {
type Item = GitCommand;
type Error = Error;
#[cfg_attr(feature = "tracing", tracing::instrument(skip(self, src), err))]
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
loop {
if src.len() < 4 {
return Ok(None);
}
let mut length_bytes = [0_u8; 4];
length_bytes.copy_from_slice(&src[..4]);
let length = u16::from_str_radix(
std::str::from_utf8(&length_bytes).map_err(Error::ParseLengthBytes)?,
16,
)
.map_err(Error::ParseLengthAsHex)? as usize;
if length == 0 {
src.advance(4);
return Ok(Some(std::mem::take(&mut self.command)));
} else if length == 1 || length == 2 {
src.advance(4);
continue;
} else if !ALLOWED_PACKET_LENGTH.contains(&length) {
return Err(Error::PacketLengthExceedsSpec(
ALLOWED_PACKET_LENGTH,
length,
));
}
if src.len() < length {
src.reserve(length - src.len());
return Ok(None);
}
let mut data = src.split_to(length).freeze();
data.advance(4);
if data.ends_with(b"\n") {
data.truncate(data.len() - 1);
}
if self.command.command.is_empty() {
self.command.command = data;
} else {
self.command.metadata.push(data);
}
}
}
}
#[cfg(test)]
mod test {
    use crate::PktLine;
    use bytes::{Bytes, BytesMut};
    use std::fmt::Write;
    use tokio_util::codec::{Decoder, Encoder};

    /// Feeds `input` to a freshly constructed codec and returns its result.
    fn decode_fresh(input: &[u8]) -> Result<Option<super::GitCommand>, super::Error> {
        let mut bytes = BytesMut::from(input);
        super::GitCodec::default().decode(&mut bytes)
    }

    #[test]
    fn encode() {
        let mut bytes = BytesMut::new();
        super::Encoder
            .encode(PktLine::Data(&[1, 2, 3, 4]), &mut bytes)
            .unwrap();
        assert_eq!(bytes.to_vec(), b"0008\x01\x02\x03\x04");
    }

    #[test]
    fn decode() {
        let mut codec = super::GitCodec::default();
        let mut bytes = BytesMut::new();
        bytes.write_str("0015agent=git/2.32.0").unwrap();
        let res = codec.decode(&mut bytes).unwrap();
        assert_eq!(res, None);
        bytes.write_char('\n').unwrap();
        let res = codec.decode(&mut bytes).unwrap();
        assert_eq!(res, None);
        bytes.write_str("0000").unwrap();
        let res = codec.decode(&mut bytes).unwrap();
        assert_eq!(
            res,
            Some(super::GitCommand {
                command: Bytes::from_static(b"agent=git/2.32.0"),
                metadata: vec![],
            })
        );
        bytes.write_str("0000").unwrap();
        let res = codec.decode(&mut bytes).unwrap();
        assert_eq!(
            res,
            Some(super::GitCommand {
                command: Bytes::new(),
                metadata: vec![],
            })
        );
        bytes.write_str("0002").unwrap();
        bytes.write_str("0005a").unwrap();
        bytes.write_str("0001").unwrap();
        bytes.write_str("0005b").unwrap();
        bytes.write_str("0000").unwrap();
        let res = codec.decode(&mut bytes).unwrap();
        assert_eq!(
            res,
            Some(super::GitCommand {
                command: Bytes::from_static(b"a"),
                metadata: vec![Bytes::from_static(b"b")],
            })
        );
    }

    #[test]
    fn decode_waits_for_complete_length_prefix() {
        let mut codec = super::GitCodec::default();
        let mut bytes = BytesMut::new();
        bytes.write_str("00").unwrap();
        assert_eq!(codec.decode(&mut bytes).unwrap(), None);
        // The partial prefix must be left in place for the next call.
        assert_eq!(&bytes[..], b"00");
    }

    #[test]
    fn decode_rejects_out_of_spec_lengths() {
        // 3 is below the minimum data-packet length (the 4-byte prefix itself).
        assert!(matches!(
            decode_fresh(b"0003"),
            Err(super::Error::PacketLengthExceedsSpec(_, 3))
        ));
        // 0xfff1 is one past the 65520-byte maximum.
        assert!(matches!(
            decode_fresh(b"fff1"),
            Err(super::Error::PacketLengthExceedsSpec(_, 0xfff1))
        ));
    }

    #[test]
    fn decode_rejects_malformed_length_prefix() {
        assert!(matches!(
            decode_fresh(b"00zz"),
            Err(super::Error::ParseLengthAsHex(_))
        ));
        assert!(matches!(
            decode_fresh(b"\xff000"),
            Err(super::Error::ParseLengthBytes(_))
        ));
    }
}