Python bindings to oxyroot. Makes reading .root files blazing fast 🚀
at dev 16 kB view raw
1use ::oxyroot::{Named, RootFile}; 2use numpy::IntoPyArray; 3use pyo3::{exceptions::PyValueError, prelude::*, IntoPyObjectExt}; 4use std::fs::File; 5use std::path::Path; 6use std::sync::Arc; 7 8use arrow::array::{ 9 ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, StringArray, UInt32Array, 10 UInt64Array, 11}; 12use arrow::datatypes::{DataType, Field, Schema}; 13use arrow::record_batch::RecordBatch; 14use once_cell::sync::Lazy; 15use parking_lot::Mutex; 16use parquet::arrow::ArrowWriter; 17use parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; 18use parquet::file::properties::WriterProperties; 19use polars::functions::concat_df_diagonal; 20use polars::prelude::*; 21use pyo3_polars::PyDataFrame; 22use rayon::prelude::*; 23 24static POOL: Lazy<Mutex<rayon::ThreadPool>> = Lazy::new(|| { 25 let num_threads = std::cmp::max(1, num_cpus::get() / 2); 26 let pool = rayon::ThreadPoolBuilder::new() 27 .num_threads(num_threads) 28 .build() 29 .unwrap(); 30 Mutex::new(pool) 31}); 32 33#[pyfunction] 34fn set_num_threads(num_threads: usize) -> PyResult<()> { 35 let pool = rayon::ThreadPoolBuilder::new() 36 .num_threads(num_threads) 37 .build() 38 .map_err(|e| PyValueError::new_err(e.to_string()))?; 39 *POOL.lock() = pool; 40 Ok(()) 41} 42 43#[pyclass(name = "RootFile")] 44struct PyRootFile { 45 #[pyo3(get)] 46 path: String, 47} 48 49#[pyclass(name = "Tree")] 50struct PyTree { 51 #[pyo3(get)] 52 path: String, 53 #[pyo3(get)] 54 name: String, 55} 56 57#[pyclass(name = "Branch")] 58struct PyBranch { 59 #[pyo3(get)] 60 path: String, 61 #[pyo3(get)] 62 tree_name: String, 63 #[pyo3(get)] 64 name: String, 65} 66 67fn tree_to_dataframe( 68 tree: &::oxyroot::ReaderTree, 69 columns: Option<Vec<String>>, 70 ignore_columns: Option<Vec<String>>, 71) -> PyResult<DataFrame> { 72 let mut branches_to_save = if let Some(columns) = columns { 73 columns 74 } else { 75 tree.branches().map(|b| b.name().to_string()).collect() 76 }; 77 78 if let Some(ignore_columns) = 
ignore_columns { 79 branches_to_save.retain(|c| !ignore_columns.contains(c)); 80 } 81 82 let mut series_vec = Vec::new(); 83 84 for branch_name in branches_to_save { 85 let branch = match tree.branch(&branch_name) { 86 Some(branch) => branch, 87 None => { 88 println!("Branch '{}' not found, skipping", branch_name); 89 continue; 90 } 91 }; 92 93 let series = match branch.item_type_name().as_str() { 94 "float" => { 95 let data = branch.as_iter::<f32>().unwrap().collect::<Vec<_>>(); 96 Series::new((&branch_name).into(), data) 97 } 98 "double" => { 99 let data = branch.as_iter::<f64>().unwrap().collect::<Vec<_>>(); 100 Series::new((&branch_name).into(), data) 101 } 102 "int32_t" => { 103 let data = branch.as_iter::<i32>().unwrap().collect::<Vec<_>>(); 104 Series::new((&branch_name).into(), data) 105 } 106 "int64_t" => { 107 let data = branch.as_iter::<i64>().unwrap().collect::<Vec<_>>(); 108 Series::new((&branch_name).into(), data) 109 } 110 "uint32_t" => { 111 let data = branch.as_iter::<u32>().unwrap().collect::<Vec<_>>(); 112 Series::new((&branch_name).into(), data) 113 } 114 "uint64_t" => { 115 let data = branch.as_iter::<u64>().unwrap().collect::<Vec<_>>(); 116 Series::new((&branch_name).into(), data) 117 } 118 "string" => { 119 let data = branch.as_iter::<String>().unwrap().collect::<Vec<_>>(); 120 Series::new((&branch_name).into(), data) 121 } 122 other => { 123 println!("Unsupported branch type: {}, skipping", other); 124 continue; 125 } 126 }; 127 series_vec.push(series); 128 } 129 130 DataFrame::new(series_vec.into_iter().map(|s| s.into()).collect()) 131 .map_err(|e| PyValueError::new_err(e.to_string())) 132} 133 134#[pymethods] 135impl PyRootFile { 136 #[new] 137 fn new(path: String) -> Self { 138 PyRootFile { path } 139 } 140 141 fn keys(&self) -> PyResult<Vec<String>> { 142 let file = RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 143 Ok(file 144 .keys() 145 .into_iter() 146 .map(|k| k.name().to_string()) 147 .collect()) 148 
} 149 150 fn __getitem__(&self, name: &str) -> PyResult<PyTree> { 151 Ok(PyTree { 152 path: self.path.clone(), 153 name: name.to_string(), 154 }) 155 } 156} 157 158#[pymethods] 159impl PyTree { 160 fn branches(&self) -> PyResult<Vec<String>> { 161 let mut file = 162 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 163 let tree = file 164 .get_tree(&self.name) 165 .map_err(|e| PyValueError::new_err(e.to_string()))?; 166 Ok(tree.branches().map(|b| b.name().to_string()).collect()) 167 } 168 169 fn __getitem__(&self, name: &str) -> PyResult<PyBranch> { 170 Ok(PyBranch { 171 path: self.path.clone(), 172 tree_name: self.name.clone(), 173 name: name.to_string(), 174 }) 175 } 176 177 fn __iter__(slf: PyRef<Self>) -> PyResult<Py<PyBranchIterator>> { 178 let branches = slf.branches()?; 179 Py::new( 180 slf.py(), 181 PyBranchIterator { 182 path: slf.path.clone(), 183 tree_name: slf.name.clone(), 184 branches: branches.into_iter(), 185 }, 186 ) 187 } 188 189 #[pyo3(signature = (columns = None, ignore_columns = None))] 190 fn arrays( 191 &self, 192 columns: Option<Vec<String>>, 193 ignore_columns: Option<Vec<String>>, 194 ) -> PyResult<PyDataFrame> { 195 let mut file = 196 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 197 let tree = file 198 .get_tree(&self.name) 199 .map_err(|e| PyValueError::new_err(e.to_string()))?; 200 let df = tree_to_dataframe(&tree, columns, ignore_columns)?; 201 Ok(PyDataFrame(df)) 202 } 203 204 #[pyo3(signature = (output_file, overwrite = false, compression = "snappy", columns = None))] 205 fn to_parquet( 206 &self, 207 output_file: String, 208 overwrite: bool, 209 compression: &str, 210 columns: Option<Vec<String>>, 211 ) -> PyResult<()> { 212 if !overwrite && Path::new(&output_file).exists() { 213 return Err(PyValueError::new_err("File exists, use overwrite=True")); 214 } 215 216 let compression = match compression { 217 "snappy" => Compression::SNAPPY, 218 "uncompressed" => 
Compression::UNCOMPRESSED, 219 "gzip" => Compression::GZIP(GzipLevel::default()), 220 "lzo" => Compression::LZO, 221 "brotli" => Compression::BROTLI(BrotliLevel::default()), 222 "lz4" => Compression::LZ4, 223 "zstd" => Compression::ZSTD(ZstdLevel::default()), 224 _ => return Err(PyValueError::new_err("Invalid compression type")), 225 }; 226 227 let mut file = 228 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 229 let tree = file 230 .get_tree(&self.name) 231 .map_err(|e| PyValueError::new_err(e.to_string()))?; 232 233 let mut fields = Vec::new(); 234 let mut arrays = Vec::new(); 235 236 let branches_to_save = if let Some(columns) = columns { 237 columns 238 } else { 239 tree.branches().map(|b| b.name().to_string()).collect() 240 }; 241 242 for branch_name in branches_to_save { 243 let branch = match tree.branch(&branch_name) { 244 Some(branch) => branch, 245 None => { 246 println!("Branch '{}' not found, skipping", branch_name); 247 continue; 248 } 249 }; 250 251 let (field, array) = match branch.item_type_name().as_str() { 252 "float" => { 253 let data = branch.as_iter::<f32>().unwrap().collect::<Vec<_>>(); 254 let array: ArrayRef = Arc::new(Float32Array::from(data)); 255 (Field::new(&branch_name, DataType::Float32, false), array) 256 } 257 "double" => { 258 let data = branch.as_iter::<f64>().unwrap().collect::<Vec<_>>(); 259 let array: ArrayRef = Arc::new(Float64Array::from(data)); 260 (Field::new(&branch_name, DataType::Float64, false), array) 261 } 262 "int32_t" => { 263 let data = branch.as_iter::<i32>().unwrap().collect::<Vec<_>>(); 264 let array: ArrayRef = Arc::new(Int32Array::from(data)); 265 (Field::new(&branch_name, DataType::Int32, false), array) 266 } 267 "int64_t" => { 268 let data = branch.as_iter::<i64>().unwrap().collect::<Vec<_>>(); 269 let array: ArrayRef = Arc::new(Int64Array::from(data)); 270 (Field::new(&branch_name, DataType::Int64, false), array) 271 } 272 "uint32_t" => { 273 let data = 
branch.as_iter::<u32>().unwrap().collect::<Vec<_>>(); 274 let array: ArrayRef = Arc::new(UInt32Array::from(data)); 275 (Field::new(&branch_name, DataType::UInt32, false), array) 276 } 277 "uint64_t" => { 278 let data = branch.as_iter::<u64>().unwrap().collect::<Vec<_>>(); 279 let array: ArrayRef = Arc::new(UInt64Array::from(data)); 280 (Field::new(&branch_name, DataType::UInt64, false), array) 281 } 282 "string" => { 283 let data = branch.as_iter::<String>().unwrap().collect::<Vec<_>>(); 284 let array: ArrayRef = Arc::new(StringArray::from(data)); 285 (Field::new(&branch_name, DataType::Utf8, false), array) 286 } 287 other => { 288 println!("Unsupported branch type: {}, skipping", other); 289 continue; 290 } 291 }; 292 fields.push(field); 293 arrays.push(array); 294 } 295 296 let schema = Arc::new(Schema::new(fields)); 297 let props = WriterProperties::builder() 298 .set_compression(compression) 299 .build(); 300 let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap(); 301 302 let file = File::create(output_file)?; 303 let mut writer = ArrowWriter::try_new(file, schema, Some(props)) 304 .map_err(|e| PyValueError::new_err(e.to_string()))?; 305 writer 306 .write(&batch) 307 .map_err(|e| PyValueError::new_err(e.to_string()))?; 308 writer 309 .close() 310 .map_err(|e| PyValueError::new_err(e.to_string()))?; 311 312 Ok(()) 313 } 314} 315 316#[pyclass] 317struct PyBranchIterator { 318 path: String, 319 tree_name: String, 320 branches: std::vec::IntoIter<String>, 321} 322 323#[pymethods] 324impl PyBranchIterator { 325 fn __iter__(slf: PyRef<Self>) -> PyRef<Self> { 326 slf 327 } 328 329 fn __next__(&mut self) -> Option<PyBranch> { 330 self.branches.next().map(|name| PyBranch { 331 path: self.path.clone(), 332 tree_name: self.tree_name.clone(), 333 name, 334 }) 335 } 336} 337 338#[pymethods] 339impl PyBranch { 340 fn array(&self, py: Python) -> PyResult<Py<PyAny>> { 341 let mut file = 342 RootFile::open(&self.path).map_err(|e| 
PyValueError::new_err(e.to_string()))?; 343 let tree = file 344 .get_tree(&self.tree_name) 345 .map_err(|e| PyValueError::new_err(e.to_string()))?; 346 let branch = tree 347 .branch(&self.name) 348 .ok_or_else(|| PyValueError::new_err("Branch not found"))?; 349 350 match branch.item_type_name().as_str() { 351 "float" => { 352 let data = branch 353 .as_iter::<f32>() 354 .map_err(|e| PyValueError::new_err(e.to_string()))? 355 .collect::<Vec<_>>(); 356 Ok(data.into_pyarray(py).into()) 357 } 358 "double" => { 359 let data = branch 360 .as_iter::<f64>() 361 .map_err(|e| PyValueError::new_err(e.to_string()))? 362 .collect::<Vec<_>>(); 363 Ok(data.into_pyarray(py).into()) 364 } 365 "int32_t" => { 366 let data = branch 367 .as_iter::<i32>() 368 .map_err(|e| PyValueError::new_err(e.to_string()))? 369 .collect::<Vec<_>>(); 370 Ok(data.into_pyarray(py).into()) 371 } 372 "int64_t" => { 373 let data = branch 374 .as_iter::<i64>() 375 .map_err(|e| PyValueError::new_err(e.to_string()))? 376 .collect::<Vec<_>>(); 377 Ok(data.into_pyarray(py).into()) 378 } 379 "uint32_t" => { 380 let data = branch 381 .as_iter::<u32>() 382 .map_err(|e| PyValueError::new_err(e.to_string()))? 383 .collect::<Vec<_>>(); 384 Ok(data.into_pyarray(py).into()) 385 } 386 "uint64_t" => { 387 let data = branch 388 .as_iter::<u64>() 389 .map_err(|e| PyValueError::new_err(e.to_string()))? 390 .collect::<Vec<_>>(); 391 Ok(data.into_pyarray(py).into()) 392 } 393 "string" => { 394 let data = branch 395 .as_iter::<String>() 396 .map_err(|e| PyValueError::new_err(e.to_string()))? 
397 .collect::<Vec<_>>(); 398 Ok(data.into_py_any(py).unwrap()) 399 } 400 other => Err(PyValueError::new_err(format!( 401 "Unsupported branch type: {}", 402 other 403 ))), 404 } 405 } 406 407 #[getter] 408 fn typename(&self) -> PyResult<String> { 409 let mut file = 410 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?; 411 let tree = file 412 .get_tree(&self.tree_name) 413 .map_err(|e| PyValueError::new_err(e.to_string()))?; 414 let branch = tree 415 .branch(&self.name) 416 .ok_or_else(|| PyValueError::new_err("Branch not found"))?; 417 Ok(branch.item_type_name()) 418 } 419} 420 421#[pyfunction] 422fn open(path: String) -> PyResult<PyRootFile> { 423 Ok(PyRootFile::new(path)) 424} 425 426#[pyfunction] 427fn version() -> PyResult<String> { 428 Ok(env!("CARGO_PKG_VERSION").to_string()) 429} 430 431#[pyfunction] 432#[pyo3(signature = (paths, tree_name, columns = None, ignore_columns = None))] 433fn concat_trees( 434 paths: Vec<String>, 435 tree_name: String, 436 columns: Option<Vec<String>>, 437 ignore_columns: Option<Vec<String>>, 438) -> PyResult<PyDataFrame> { 439 let mut all_paths = Vec::new(); 440 for path in paths { 441 for entry in glob::glob(&path).map_err(|e| PyValueError::new_err(e.to_string()))? 
{ 442 match entry { 443 Ok(path) => { 444 all_paths.push(path.to_str().unwrap().to_string()); 445 } 446 Err(e) => return Err(PyValueError::new_err(e.to_string())), 447 } 448 } 449 } 450 451 let pool = POOL.lock(); 452 let dfs: Vec<DataFrame> = pool.install(|| { 453 all_paths 454 .par_iter() 455 .map(|path| { 456 let mut file = 457 RootFile::open(path).map_err(|e| PyValueError::new_err(e.to_string()))?; 458 let tree = file 459 .get_tree(&tree_name) 460 .map_err(|e| PyValueError::new_err(e.to_string()))?; 461 tree_to_dataframe(&tree, columns.clone(), ignore_columns.clone()) 462 }) 463 .filter_map(Result::ok) 464 .collect() 465 }); 466 467 if dfs.is_empty() { 468 return Ok(PyDataFrame(DataFrame::default())); 469 } 470 471 let combined_df = concat_df_diagonal(&dfs).map_err(|e| PyValueError::new_err(e.to_string()))?; 472 473 Ok(PyDataFrame(combined_df)) 474} 475 476/// A Python module to read root files, implemented in Rust. 477#[pymodule] 478fn oxyroot(m: &Bound<'_, PyModule>) -> PyResult<()> { 479 m.add_function(wrap_pyfunction!(version, m)?)?; 480 m.add_function(wrap_pyfunction!(open, m)?)?; 481 m.add_function(wrap_pyfunction!(concat_trees, m)?)?; 482 m.add_function(wrap_pyfunction!(set_num_threads, m)?)?; 483 m.add_class::<PyRootFile>()?; 484 m.add_class::<PyTree>()?; 485 m.add_class::<PyBranch>()?; 486 m.add_class::<PyBranchIterator>()?; 487 Ok(()) 488}