Python bindings to oxyroot. Makes reading .root files blazing fast 🚀
1use ::oxyroot::{Named, RootFile};
2use numpy::IntoPyArray;
3use pyo3::{exceptions::PyValueError, prelude::*, IntoPyObjectExt};
4use std::fs::File;
5use std::path::Path;
6use std::sync::Arc;
7
8use arrow::array::{
9 ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, StringArray, UInt32Array,
10 UInt64Array,
11};
12use arrow::datatypes::{DataType, Field, Schema};
13use arrow::record_batch::RecordBatch;
14use once_cell::sync::Lazy;
15use parking_lot::Mutex;
16use parquet::arrow::ArrowWriter;
17use parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
18use parquet::file::properties::WriterProperties;
19use polars::functions::concat_df_diagonal;
20use polars::prelude::*;
21use pyo3_polars::PyDataFrame;
22use rayon::prelude::*;
23
24static POOL: Lazy<Mutex<rayon::ThreadPool>> = Lazy::new(|| {
25 let num_threads = std::cmp::max(1, num_cpus::get() / 2);
26 let pool = rayon::ThreadPoolBuilder::new()
27 .num_threads(num_threads)
28 .build()
29 .unwrap();
30 Mutex::new(pool)
31});
32
33#[pyfunction]
34fn set_num_threads(num_threads: usize) -> PyResult<()> {
35 let pool = rayon::ThreadPoolBuilder::new()
36 .num_threads(num_threads)
37 .build()
38 .map_err(|e| PyValueError::new_err(e.to_string()))?;
39 *POOL.lock() = pool;
40 Ok(())
41}
42
43#[pyclass(name = "RootFile")]
44struct PyRootFile {
45 #[pyo3(get)]
46 path: String,
47}
48
49#[pyclass(name = "Tree")]
50struct PyTree {
51 #[pyo3(get)]
52 path: String,
53 #[pyo3(get)]
54 name: String,
55}
56
57#[pyclass(name = "Branch")]
58struct PyBranch {
59 #[pyo3(get)]
60 path: String,
61 #[pyo3(get)]
62 tree_name: String,
63 #[pyo3(get)]
64 name: String,
65}
66
67fn tree_to_dataframe(
68 tree: &::oxyroot::ReaderTree,
69 columns: Option<Vec<String>>,
70 ignore_columns: Option<Vec<String>>,
71) -> PyResult<DataFrame> {
72 let mut branches_to_save = if let Some(columns) = columns {
73 columns
74 } else {
75 tree.branches().map(|b| b.name().to_string()).collect()
76 };
77
78 if let Some(ignore_columns) = ignore_columns {
79 branches_to_save.retain(|c| !ignore_columns.contains(c));
80 }
81
82 let mut series_vec = Vec::new();
83
84 for branch_name in branches_to_save {
85 let branch = match tree.branch(&branch_name) {
86 Some(branch) => branch,
87 None => {
88 println!("Branch '{}' not found, skipping", branch_name);
89 continue;
90 }
91 };
92
93 let series = match branch.item_type_name().as_str() {
94 "float" => {
95 let data = branch.as_iter::<f32>().unwrap().collect::<Vec<_>>();
96 Series::new((&branch_name).into(), data)
97 }
98 "double" => {
99 let data = branch.as_iter::<f64>().unwrap().collect::<Vec<_>>();
100 Series::new((&branch_name).into(), data)
101 }
102 "int32_t" => {
103 let data = branch.as_iter::<i32>().unwrap().collect::<Vec<_>>();
104 Series::new((&branch_name).into(), data)
105 }
106 "int64_t" => {
107 let data = branch.as_iter::<i64>().unwrap().collect::<Vec<_>>();
108 Series::new((&branch_name).into(), data)
109 }
110 "uint32_t" => {
111 let data = branch.as_iter::<u32>().unwrap().collect::<Vec<_>>();
112 Series::new((&branch_name).into(), data)
113 }
114 "uint64_t" => {
115 let data = branch.as_iter::<u64>().unwrap().collect::<Vec<_>>();
116 Series::new((&branch_name).into(), data)
117 }
118 "string" => {
119 let data = branch.as_iter::<String>().unwrap().collect::<Vec<_>>();
120 Series::new((&branch_name).into(), data)
121 }
122 other => {
123 println!("Unsupported branch type: {}, skipping", other);
124 continue;
125 }
126 };
127 series_vec.push(series);
128 }
129
130 DataFrame::new(series_vec.into_iter().map(|s| s.into()).collect())
131 .map_err(|e| PyValueError::new_err(e.to_string()))
132}
133
134#[pymethods]
135impl PyRootFile {
136 #[new]
137 fn new(path: String) -> Self {
138 PyRootFile { path }
139 }
140
141 fn keys(&self) -> PyResult<Vec<String>> {
142 let file = RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
143 Ok(file
144 .keys()
145 .into_iter()
146 .map(|k| k.name().to_string())
147 .collect())
148 }
149
150 fn __getitem__(&self, name: &str) -> PyResult<PyTree> {
151 Ok(PyTree {
152 path: self.path.clone(),
153 name: name.to_string(),
154 })
155 }
156}
157
158#[pymethods]
159impl PyTree {
160 fn branches(&self) -> PyResult<Vec<String>> {
161 let mut file =
162 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
163 let tree = file
164 .get_tree(&self.name)
165 .map_err(|e| PyValueError::new_err(e.to_string()))?;
166 Ok(tree.branches().map(|b| b.name().to_string()).collect())
167 }
168
169 fn __getitem__(&self, name: &str) -> PyResult<PyBranch> {
170 Ok(PyBranch {
171 path: self.path.clone(),
172 tree_name: self.name.clone(),
173 name: name.to_string(),
174 })
175 }
176
177 fn __iter__(slf: PyRef<Self>) -> PyResult<Py<PyBranchIterator>> {
178 let branches = slf.branches()?;
179 Py::new(
180 slf.py(),
181 PyBranchIterator {
182 path: slf.path.clone(),
183 tree_name: slf.name.clone(),
184 branches: branches.into_iter(),
185 },
186 )
187 }
188
189 #[pyo3(signature = (columns = None, ignore_columns = None))]
190 fn arrays(
191 &self,
192 columns: Option<Vec<String>>,
193 ignore_columns: Option<Vec<String>>,
194 ) -> PyResult<PyDataFrame> {
195 let mut file =
196 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
197 let tree = file
198 .get_tree(&self.name)
199 .map_err(|e| PyValueError::new_err(e.to_string()))?;
200 let df = tree_to_dataframe(&tree, columns, ignore_columns)?;
201 Ok(PyDataFrame(df))
202 }
203
204 #[pyo3(signature = (output_file, overwrite = false, compression = "snappy", columns = None))]
205 fn to_parquet(
206 &self,
207 output_file: String,
208 overwrite: bool,
209 compression: &str,
210 columns: Option<Vec<String>>,
211 ) -> PyResult<()> {
212 if !overwrite && Path::new(&output_file).exists() {
213 return Err(PyValueError::new_err("File exists, use overwrite=True"));
214 }
215
216 let compression = match compression {
217 "snappy" => Compression::SNAPPY,
218 "uncompressed" => Compression::UNCOMPRESSED,
219 "gzip" => Compression::GZIP(GzipLevel::default()),
220 "lzo" => Compression::LZO,
221 "brotli" => Compression::BROTLI(BrotliLevel::default()),
222 "lz4" => Compression::LZ4,
223 "zstd" => Compression::ZSTD(ZstdLevel::default()),
224 _ => return Err(PyValueError::new_err("Invalid compression type")),
225 };
226
227 let mut file =
228 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
229 let tree = file
230 .get_tree(&self.name)
231 .map_err(|e| PyValueError::new_err(e.to_string()))?;
232
233 let mut fields = Vec::new();
234 let mut arrays = Vec::new();
235
236 let branches_to_save = if let Some(columns) = columns {
237 columns
238 } else {
239 tree.branches().map(|b| b.name().to_string()).collect()
240 };
241
242 for branch_name in branches_to_save {
243 let branch = match tree.branch(&branch_name) {
244 Some(branch) => branch,
245 None => {
246 println!("Branch '{}' not found, skipping", branch_name);
247 continue;
248 }
249 };
250
251 let (field, array) = match branch.item_type_name().as_str() {
252 "float" => {
253 let data = branch.as_iter::<f32>().unwrap().collect::<Vec<_>>();
254 let array: ArrayRef = Arc::new(Float32Array::from(data));
255 (Field::new(&branch_name, DataType::Float32, false), array)
256 }
257 "double" => {
258 let data = branch.as_iter::<f64>().unwrap().collect::<Vec<_>>();
259 let array: ArrayRef = Arc::new(Float64Array::from(data));
260 (Field::new(&branch_name, DataType::Float64, false), array)
261 }
262 "int32_t" => {
263 let data = branch.as_iter::<i32>().unwrap().collect::<Vec<_>>();
264 let array: ArrayRef = Arc::new(Int32Array::from(data));
265 (Field::new(&branch_name, DataType::Int32, false), array)
266 }
267 "int64_t" => {
268 let data = branch.as_iter::<i64>().unwrap().collect::<Vec<_>>();
269 let array: ArrayRef = Arc::new(Int64Array::from(data));
270 (Field::new(&branch_name, DataType::Int64, false), array)
271 }
272 "uint32_t" => {
273 let data = branch.as_iter::<u32>().unwrap().collect::<Vec<_>>();
274 let array: ArrayRef = Arc::new(UInt32Array::from(data));
275 (Field::new(&branch_name, DataType::UInt32, false), array)
276 }
277 "uint64_t" => {
278 let data = branch.as_iter::<u64>().unwrap().collect::<Vec<_>>();
279 let array: ArrayRef = Arc::new(UInt64Array::from(data));
280 (Field::new(&branch_name, DataType::UInt64, false), array)
281 }
282 "string" => {
283 let data = branch.as_iter::<String>().unwrap().collect::<Vec<_>>();
284 let array: ArrayRef = Arc::new(StringArray::from(data));
285 (Field::new(&branch_name, DataType::Utf8, false), array)
286 }
287 other => {
288 println!("Unsupported branch type: {}, skipping", other);
289 continue;
290 }
291 };
292 fields.push(field);
293 arrays.push(array);
294 }
295
296 let schema = Arc::new(Schema::new(fields));
297 let props = WriterProperties::builder()
298 .set_compression(compression)
299 .build();
300 let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
301
302 let file = File::create(output_file)?;
303 let mut writer = ArrowWriter::try_new(file, schema, Some(props))
304 .map_err(|e| PyValueError::new_err(e.to_string()))?;
305 writer
306 .write(&batch)
307 .map_err(|e| PyValueError::new_err(e.to_string()))?;
308 writer
309 .close()
310 .map_err(|e| PyValueError::new_err(e.to_string()))?;
311
312 Ok(())
313 }
314}
315
316#[pyclass]
317struct PyBranchIterator {
318 path: String,
319 tree_name: String,
320 branches: std::vec::IntoIter<String>,
321}
322
323#[pymethods]
324impl PyBranchIterator {
325 fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
326 slf
327 }
328
329 fn __next__(&mut self) -> Option<PyBranch> {
330 self.branches.next().map(|name| PyBranch {
331 path: self.path.clone(),
332 tree_name: self.tree_name.clone(),
333 name,
334 })
335 }
336}
337
338#[pymethods]
339impl PyBranch {
340 fn array(&self, py: Python) -> PyResult<Py<PyAny>> {
341 let mut file =
342 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
343 let tree = file
344 .get_tree(&self.tree_name)
345 .map_err(|e| PyValueError::new_err(e.to_string()))?;
346 let branch = tree
347 .branch(&self.name)
348 .ok_or_else(|| PyValueError::new_err("Branch not found"))?;
349
350 match branch.item_type_name().as_str() {
351 "float" => {
352 let data = branch
353 .as_iter::<f32>()
354 .map_err(|e| PyValueError::new_err(e.to_string()))?
355 .collect::<Vec<_>>();
356 Ok(data.into_pyarray(py).into())
357 }
358 "double" => {
359 let data = branch
360 .as_iter::<f64>()
361 .map_err(|e| PyValueError::new_err(e.to_string()))?
362 .collect::<Vec<_>>();
363 Ok(data.into_pyarray(py).into())
364 }
365 "int32_t" => {
366 let data = branch
367 .as_iter::<i32>()
368 .map_err(|e| PyValueError::new_err(e.to_string()))?
369 .collect::<Vec<_>>();
370 Ok(data.into_pyarray(py).into())
371 }
372 "int64_t" => {
373 let data = branch
374 .as_iter::<i64>()
375 .map_err(|e| PyValueError::new_err(e.to_string()))?
376 .collect::<Vec<_>>();
377 Ok(data.into_pyarray(py).into())
378 }
379 "uint32_t" => {
380 let data = branch
381 .as_iter::<u32>()
382 .map_err(|e| PyValueError::new_err(e.to_string()))?
383 .collect::<Vec<_>>();
384 Ok(data.into_pyarray(py).into())
385 }
386 "uint64_t" => {
387 let data = branch
388 .as_iter::<u64>()
389 .map_err(|e| PyValueError::new_err(e.to_string()))?
390 .collect::<Vec<_>>();
391 Ok(data.into_pyarray(py).into())
392 }
393 "string" => {
394 let data = branch
395 .as_iter::<String>()
396 .map_err(|e| PyValueError::new_err(e.to_string()))?
397 .collect::<Vec<_>>();
398 Ok(data.into_py_any(py).unwrap())
399 }
400 other => Err(PyValueError::new_err(format!(
401 "Unsupported branch type: {}",
402 other
403 ))),
404 }
405 }
406
407 #[getter]
408 fn typename(&self) -> PyResult<String> {
409 let mut file =
410 RootFile::open(&self.path).map_err(|e| PyValueError::new_err(e.to_string()))?;
411 let tree = file
412 .get_tree(&self.tree_name)
413 .map_err(|e| PyValueError::new_err(e.to_string()))?;
414 let branch = tree
415 .branch(&self.name)
416 .ok_or_else(|| PyValueError::new_err("Branch not found"))?;
417 Ok(branch.item_type_name())
418 }
419}
420
421#[pyfunction]
422fn open(path: String) -> PyResult<PyRootFile> {
423 Ok(PyRootFile::new(path))
424}
425
426#[pyfunction]
427fn version() -> PyResult<String> {
428 Ok(env!("CARGO_PKG_VERSION").to_string())
429}
430
431#[pyfunction]
432#[pyo3(signature = (paths, tree_name, columns = None, ignore_columns = None))]
433fn concat_trees(
434 paths: Vec<String>,
435 tree_name: String,
436 columns: Option<Vec<String>>,
437 ignore_columns: Option<Vec<String>>,
438) -> PyResult<PyDataFrame> {
439 let mut all_paths = Vec::new();
440 for path in paths {
441 for entry in glob::glob(&path).map_err(|e| PyValueError::new_err(e.to_string()))? {
442 match entry {
443 Ok(path) => {
444 all_paths.push(path.to_str().unwrap().to_string());
445 }
446 Err(e) => return Err(PyValueError::new_err(e.to_string())),
447 }
448 }
449 }
450
451 let pool = POOL.lock();
452 let dfs: Vec<DataFrame> = pool.install(|| {
453 all_paths
454 .par_iter()
455 .map(|path| {
456 let mut file =
457 RootFile::open(path).map_err(|e| PyValueError::new_err(e.to_string()))?;
458 let tree = file
459 .get_tree(&tree_name)
460 .map_err(|e| PyValueError::new_err(e.to_string()))?;
461 tree_to_dataframe(&tree, columns.clone(), ignore_columns.clone())
462 })
463 .filter_map(Result::ok)
464 .collect()
465 });
466
467 if dfs.is_empty() {
468 return Ok(PyDataFrame(DataFrame::default()));
469 }
470
471 let combined_df = concat_df_diagonal(&dfs).map_err(|e| PyValueError::new_err(e.to_string()))?;
472
473 Ok(PyDataFrame(combined_df))
474}
475
476/// A Python module to read root files, implemented in Rust.
477#[pymodule]
478fn oxyroot(m: &Bound<'_, PyModule>) -> PyResult<()> {
479 m.add_function(wrap_pyfunction!(version, m)?)?;
480 m.add_function(wrap_pyfunction!(open, m)?)?;
481 m.add_function(wrap_pyfunction!(concat_trees, m)?)?;
482 m.add_function(wrap_pyfunction!(set_num_threads, m)?)?;
483 m.add_class::<PyRootFile>()?;
484 m.add_class::<PyTree>()?;
485 m.add_class::<PyBranch>()?;
486 m.add_class::<PyBranchIterator>()?;
487 Ok(())
488}