hopr_chain_indexer/snapshot/
extract.rs1use std::{
7 fs,
8 fs::File,
9 path::{Component::ParentDir, Path},
10};
11
12use async_compression::futures::bufread::XzDecoder;
13use async_tar::Archive;
14use futures_util::{
15 StreamExt,
16 io::{AllowStdIo, BufReader as FuturesBufReader},
17};
18use tracing::{debug, error, info};
19
20use crate::snapshot::error::{SnapshotError, SnapshotResult};
21
/// Extracts snapshot archives, unpacking only an allow-list of file names.
///
/// `Debug` and `Clone` are derived so the extractor can be logged and shared
/// like any other plain configuration value.
#[derive(Debug, Clone)]
pub struct SnapshotExtractor {
    /// File names (without any directory part) that may be unpacked from an archive.
    expected_files: Vec<String>,
}
30
31impl SnapshotExtractor {
32 pub fn new() -> Self {
39 Self {
40 expected_files: vec![
41 "hopr_logs.db".to_string(),
42 "hopr_logs.db-wal".to_string(),
43 "hopr_logs.db-shm".to_string(),
44 ],
45 }
46 }
47
48 pub async fn extract_snapshot(&self, archive_path: &Path, target_dir: &Path) -> SnapshotResult<Vec<String>> {
73 info!(from = %archive_path.display(), to = %target_dir.display(), "Extracting snapshot");
74
75 fs::create_dir_all(target_dir)?;
77
78 let extracted_files = self.extract_tar_xz(archive_path, target_dir).await?;
79
80 info!(nr_of_files = extracted_files.len(), "Extracted snapshot files");
81 Ok(extracted_files)
82 }
83
84 async fn extract_tar_xz(&self, archive_path: &Path, target_dir: &Path) -> SnapshotResult<Vec<String>> {
86 let file = File::open(archive_path).map_err(SnapshotError::Io)?;
88 let file_reader = AllowStdIo::new(file);
89
90 let buf_reader = FuturesBufReader::new(file_reader);
92 let decoder = XzDecoder::new(buf_reader);
93 let archive = Archive::new(decoder);
94
95 let mut extracted_files = Vec::new();
96 let mut entries = archive.entries().map_err(SnapshotError::Io)?;
97
98 while let Some(entry_result) = entries.next().await {
99 let mut entry = entry_result.map_err(SnapshotError::Io)?;
100 let path_buf = entry.path().map_err(SnapshotError::Io)?.to_path_buf();
101
102 if !path_is_safe(path_buf.as_path().into()) {
105 return Err(SnapshotError::InvalidFormat(
106 "Archive contains parent directory references".to_string(),
107 ));
108 }
109
110 let filename = path_buf
112 .file_name()
113 .and_then(|s| s.to_str())
114 .ok_or_else(|| SnapshotError::InvalidFormat("Invalid filename".to_string()))?;
115
116 if self.expected_files.iter().any(|f| f == filename) {
118 entry.unpack_in(target_dir).await.map_err(SnapshotError::Io)?;
120 extracted_files.push(filename.to_string());
121
122 debug!(%filename, "Extracted file");
123 } else {
124 error!(%filename, "Skipping unexpected file in archive");
125 }
126 }
127
128 if !extracted_files.contains(&"hopr_logs.db".to_string()) {
130 return Err(SnapshotError::InvalidFormat(
131 "Archive does not contain hopr_logs.db".to_string(),
132 ));
133 }
134
135 Ok(extracted_files)
136 }
137
138 pub async fn validate_archive(&self, archive_path: &Path) -> SnapshotResult<Vec<String>> {
140 self.list_archive_contents(archive_path).await
141 }
142
143 async fn list_archive_contents(&self, archive_path: &Path) -> SnapshotResult<Vec<String>> {
145 let file = File::open(archive_path).map_err(SnapshotError::Io)?;
147 let file_reader = AllowStdIo::new(file);
148
149 let buf_reader = FuturesBufReader::new(file_reader);
151 let decoder = XzDecoder::new(buf_reader);
152 let archive = Archive::new(decoder);
153
154 let mut files = Vec::new();
155 let mut entries = archive.entries().map_err(SnapshotError::Io)?;
156
157 while let Some(entry_result) = entries.next().await {
158 let entry = entry_result.map_err(SnapshotError::Io)?;
159 let path = entry.path().map_err(SnapshotError::Io)?;
160
161 if let Some(filename) = path.file_name().and_then(|s| s.to_str()) {
162 files.push(filename.to_string());
163 }
164 }
165
166 Ok(files)
167 }
168}
169
/// Reports whether `path` is free of `..` (parent-directory) components.
///
/// Only `..` is rejected: absolute paths and `.` components pass this check,
/// as does the empty path.
fn path_is_safe(path: &Path) -> bool {
    path.components().all(|component| component != ParentDir)
}
174
175impl Default for SnapshotExtractor {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::*;
    use crate::snapshot::test_utils::create_test_archive;

    /// A well-formed test archive extracts and yields the main database file.
    #[tokio::test]
    async fn test_extraction() {
        let workspace = TempDir::new().unwrap();
        let extractor = SnapshotExtractor::new();

        let archive = create_test_archive(&workspace, None).await.unwrap();
        let destination = workspace.path().join("extracted");

        let outcome = extractor.extract_snapshot(&archive, &destination).await;
        assert!(outcome.is_ok(), "Extraction should succeed");

        let names = outcome.unwrap();
        assert!(names.contains(&"hopr_logs.db".to_string()));
        assert!(destination.join("hopr_logs.db").exists());
    }

    /// Extraction must not leave a database file in the parent of the target directory.
    #[tokio::test]
    async fn test_archive_security_validation() {
        let workspace = TempDir::new().unwrap();
        let extractor = SnapshotExtractor::new();

        let archive = create_test_archive(&workspace, None).await.unwrap();
        let destination = workspace.path().join("extract");

        // Location a file would end up in if it escaped one level above the target.
        let escaped_db = destination.parent().unwrap().join("hopr_logs.db");
        assert!(!escaped_db.exists());

        let outcome = extractor.extract_snapshot(&archive, &destination).await;
        assert!(outcome.is_ok());

        let names = outcome.unwrap();
        assert!(names.contains(&"hopr_logs.db".to_string()));
        assert!(!escaped_db.exists());
    }

    /// Arbitrary bytes that are not a `.tar.xz` stream are rejected.
    #[tokio::test]
    async fn test_invalid_archive() {
        let workspace = TempDir::new().unwrap();
        let extractor = SnapshotExtractor::new();

        let bogus_archive = workspace.path().join("invalid.tar.xz");
        fs::write(&bogus_archive, "not a valid archive").unwrap();

        let destination = workspace.path().join("extracted");
        let outcome = extractor.extract_snapshot(&bogus_archive, &destination).await;

        assert!(outcome.is_err(), "Extraction should fail for invalid archive");
    }

    /// `path_is_safe` accepts plain names and rejects any path containing `..`.
    #[test_log::test(tokio::test)]
    async fn test_path_traversal_protection() {
        assert!(path_is_safe(Path::new("good.db")));

        let traversal_attempts = [
            "../malicious.db",
            "../../malicious.db",
            "../back/../malicious.db",
            "forward/../../malicious.db",
        ];
        for attempt in traversal_attempts {
            assert!(!path_is_safe(Path::new(attempt)), "{attempt} must be rejected");
        }
    }
}