we (web engine): Experimental web browser project to understand the limits of Claude
1//! JPEG decoder (JFIF baseline DCT, ITU-T T.81).
2//!
3//! Decodes baseline JPEG images (SOF0) into RGBA8 pixel data. Supports
4//! 4:4:4, 4:2:2, and 4:2:0 chroma subsampling, Huffman entropy coding,
5//! and restart markers (DRI/RST).
6
7use crate::pixel::{Image, ImageError};
8
9// ---------------------------------------------------------------------------
10// JPEG marker constants (second byte after 0xFF prefix)
11// ---------------------------------------------------------------------------
12
// Marker codes, i.e. the byte that follows the 0xFF prefix in the stream.

/// Start of image.
const MARKER_SOI: u8 = 0xd8;
/// End of image.
const MARKER_EOI: u8 = 0xd9;
/// Start of scan.
const MARKER_SOS: u8 = 0xda;
/// Define quantization table(s).
const MARKER_DQT: u8 = 0xdb;
/// Define Huffman table(s).
const MARKER_DHT: u8 = 0xc4;
/// Start of frame, baseline DCT.
const MARKER_SOF0: u8 = 0xc0;
/// Define restart interval.
const MARKER_DRI: u8 = 0xdd;
20
21fn decode_err(msg: &str) -> ImageError {
22 ImageError::Decode(msg.to_string())
23}
24
25// ---------------------------------------------------------------------------
26// Zigzag order table
27// ---------------------------------------------------------------------------
28
/// Zigzag scan order: entry `k` is the natural (row-major 8x8) index of the `k`-th
/// coefficient as it appears in the entropy-coded stream.
#[rustfmt::skip]
const ZIGZAG: [usize; 64] = [
     0,  1,  8, 16,  9,  2,  3, 10,
    17, 24, 32, 25, 18, 11,  4,  5,
    12, 19, 26, 33, 40, 48, 41, 34,
    27, 20, 13,  6,  7, 14, 21, 28,
    35, 42, 49, 56, 57, 50, 43, 36,
    29, 22, 15, 23, 30, 37, 44, 51,
    58, 59, 52, 45, 38, 31, 39, 46,
    53, 60, 61, 54, 47, 55, 62, 63,
];
35
36// ---------------------------------------------------------------------------
37// Byte reader
38// ---------------------------------------------------------------------------
39
40struct JpegReader<'a> {
41 data: &'a [u8],
42 pos: usize,
43}
44
45impl<'a> JpegReader<'a> {
46 fn new(data: &'a [u8]) -> Self {
47 Self { data, pos: 0 }
48 }
49
50 fn remaining(&self) -> usize {
51 self.data.len().saturating_sub(self.pos)
52 }
53
54 fn read_byte(&mut self) -> Result<u8, ImageError> {
55 if self.pos >= self.data.len() {
56 return Err(decode_err("unexpected end of JPEG data"));
57 }
58 let b = self.data[self.pos];
59 self.pos += 1;
60 Ok(b)
61 }
62
63 fn read_u16_be(&mut self) -> Result<u16, ImageError> {
64 let hi = self.read_byte()? as u16;
65 let lo = self.read_byte()? as u16;
66 Ok((hi << 8) | lo)
67 }
68
69 fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], ImageError> {
70 if self.pos + n > self.data.len() {
71 return Err(decode_err("unexpected end of JPEG data"));
72 }
73 let slice = &self.data[self.pos..self.pos + n];
74 self.pos += n;
75 Ok(slice)
76 }
77
78 fn skip(&mut self, n: usize) -> Result<(), ImageError> {
79 if self.pos + n > self.data.len() {
80 return Err(decode_err("unexpected end of JPEG data"));
81 }
82 self.pos += n;
83 Ok(())
84 }
85}
86
87// ---------------------------------------------------------------------------
88// Quantization table
89// ---------------------------------------------------------------------------
90
91struct QuantTable {
92 /// 64 quantization values in zigzag order.
93 values: [u16; 64],
94}
95
96fn parse_dqt(
97 reader: &mut JpegReader,
98 tables: &mut [Option<QuantTable>; 4],
99) -> Result<(), ImageError> {
100 let length = reader.read_u16_be()? as usize;
101 if length < 2 {
102 return Err(decode_err("DQT: invalid length"));
103 }
104 let mut remaining = length - 2;
105
106 while remaining > 0 {
107 let pq_tq = reader.read_byte()?;
108 remaining -= 1;
109 let precision = pq_tq >> 4;
110 let table_id = (pq_tq & 0x0F) as usize;
111 if table_id >= 4 {
112 return Err(decode_err("DQT: table id out of range"));
113 }
114
115 let mut values = [0u16; 64];
116 if precision == 0 {
117 // 8-bit values
118 if remaining < 64 {
119 return Err(decode_err("DQT: truncated 8-bit table"));
120 }
121 for v in &mut values {
122 *v = reader.read_byte()? as u16;
123 }
124 remaining -= 64;
125 } else if precision == 1 {
126 // 16-bit values
127 if remaining < 128 {
128 return Err(decode_err("DQT: truncated 16-bit table"));
129 }
130 for v in &mut values {
131 *v = reader.read_u16_be()?;
132 }
133 remaining -= 128;
134 } else {
135 return Err(decode_err("DQT: unsupported precision"));
136 }
137
138 tables[table_id] = Some(QuantTable { values });
139 }
140 Ok(())
141}
142
143// ---------------------------------------------------------------------------
144// Huffman table
145// ---------------------------------------------------------------------------
146
/// Canonical Huffman table, plus the per-length lookup arrays used to decode it one
/// bit at a time.
struct HuffTable {
    /// Number of codes of each length (1..=16).
    counts: [u8; 16],
    /// Symbol values in order of increasing code length.
    symbols: Vec<u8>,
    /// Per length: `symbol index = code + val_offset`.
    val_offset: [i32; 16],
    /// Maximum code value for each length (-1 if no codes).
    max_code: [i32; 16],
}

impl HuffTable {
    /// Derive the lookup arrays from the per-length code counts and the symbol list
    /// of a DHT segment.
    fn build(counts: [u8; 16], symbols: Vec<u8>) -> Self {
        let mut max_code = [-1i32; 16];
        let mut val_offset = [0i32; 16];

        // `next_code` is the first canonical code of the current length.
        // `first_symbol` is the index in `symbols` of the symbol that code maps to.
        let mut next_code = 0i32;
        let mut first_symbol = 0i32;
        for (len_idx, &count) in counts.iter().enumerate() {
            let count = i32::from(count);
            if count != 0 {
                val_offset[len_idx] = first_symbol - next_code;
                max_code[len_idx] = next_code + count - 1;
                next_code += count;
                first_symbol += count;
            }
            // Codes one bit longer start at twice the next unused code.
            next_code <<= 1;
        }

        HuffTable {
            counts,
            symbols,
            val_offset,
            max_code,
        }
    }
}
183
184fn parse_dht(
185 reader: &mut JpegReader,
186 dc_tables: &mut [Option<HuffTable>; 4],
187 ac_tables: &mut [Option<HuffTable>; 4],
188) -> Result<(), ImageError> {
189 let length = reader.read_u16_be()? as usize;
190 if length < 2 {
191 return Err(decode_err("DHT: invalid length"));
192 }
193 let mut remaining = length - 2;
194
195 while remaining > 0 {
196 let tc_th = reader.read_byte()?;
197 remaining -= 1;
198 let table_class = tc_th >> 4; // 0=DC, 1=AC
199 let table_id = (tc_th & 0x0F) as usize;
200 if table_class > 1 || table_id >= 4 {
201 return Err(decode_err("DHT: invalid class or table id"));
202 }
203
204 if remaining < 16 {
205 return Err(decode_err("DHT: truncated counts"));
206 }
207 let mut counts = [0u8; 16];
208 for c in &mut counts {
209 *c = reader.read_byte()?;
210 }
211 remaining -= 16;
212
213 let total: usize = counts.iter().map(|&c| c as usize).sum();
214 if remaining < total {
215 return Err(decode_err("DHT: truncated symbols"));
216 }
217 let symbols = reader.read_bytes(total)?.to_vec();
218 remaining -= total;
219
220 let table = HuffTable::build(counts, symbols);
221 if table_class == 0 {
222 dc_tables[table_id] = Some(table);
223 } else {
224 ac_tables[table_id] = Some(table);
225 }
226 }
227 Ok(())
228}
229
230// ---------------------------------------------------------------------------
231// Frame header (SOF0)
232// ---------------------------------------------------------------------------
233
234struct ComponentInfo {
235 id: u8,
236 h_sample: u8,
237 v_sample: u8,
238 quant_table_id: u8,
239}
240
241struct FrameHeader {
242 height: u16,
243 width: u16,
244 components: Vec<ComponentInfo>,
245 h_max: u8,
246 v_max: u8,
247}
248
249fn parse_sof0(reader: &mut JpegReader) -> Result<FrameHeader, ImageError> {
250 let length = reader.read_u16_be()? as usize;
251 if length < 8 {
252 return Err(decode_err("SOF0: invalid length"));
253 }
254 let precision = reader.read_byte()?;
255 if precision != 8 {
256 return Err(decode_err("SOF0: only 8-bit precision supported"));
257 }
258 let height = reader.read_u16_be()?;
259 let width = reader.read_u16_be()?;
260 let num_components = reader.read_byte()? as usize;
261 if num_components == 0 || num_components > 4 {
262 return Err(decode_err("SOF0: invalid number of components"));
263 }
264 if length != 8 + 3 * num_components {
265 return Err(decode_err("SOF0: length mismatch"));
266 }
267
268 let mut components = Vec::with_capacity(num_components);
269 let mut h_max = 1u8;
270 let mut v_max = 1u8;
271 for _ in 0..num_components {
272 let id = reader.read_byte()?;
273 let hv = reader.read_byte()?;
274 let h_sample = hv >> 4;
275 let v_sample = hv & 0x0F;
276 if h_sample == 0 || h_sample > 4 || v_sample == 0 || v_sample > 4 {
277 return Err(decode_err("SOF0: invalid sampling factor"));
278 }
279 let quant_table_id = reader.read_byte()?;
280 h_max = h_max.max(h_sample);
281 v_max = v_max.max(v_sample);
282 components.push(ComponentInfo {
283 id,
284 h_sample,
285 v_sample,
286 quant_table_id,
287 });
288 }
289
290 if width == 0 || height == 0 {
291 return Err(decode_err("SOF0: zero dimension"));
292 }
293
294 Ok(FrameHeader {
295 height,
296 width,
297 components,
298 h_max,
299 v_max,
300 })
301}
302
303// ---------------------------------------------------------------------------
304// Scan header (SOS)
305// ---------------------------------------------------------------------------
306
307struct ScanComponentSelector {
308 component_index: usize,
309 dc_table_id: u8,
310 ac_table_id: u8,
311}
312
313struct ScanHeader {
314 components: Vec<ScanComponentSelector>,
315}
316
317fn parse_sos(reader: &mut JpegReader, frame: &FrameHeader) -> Result<ScanHeader, ImageError> {
318 let length = reader.read_u16_be()? as usize;
319 let num_components = reader.read_byte()? as usize;
320 if num_components == 0 || num_components > 4 {
321 return Err(decode_err("SOS: invalid number of components"));
322 }
323 if length != 6 + 2 * num_components {
324 return Err(decode_err("SOS: length mismatch"));
325 }
326
327 let mut components = Vec::with_capacity(num_components);
328 for _ in 0..num_components {
329 let cs = reader.read_byte()?;
330 let td_ta = reader.read_byte()?;
331 let dc_table_id = td_ta >> 4;
332 let ac_table_id = td_ta & 0x0F;
333
334 // Find component index by id
335 let component_index = frame
336 .components
337 .iter()
338 .position(|c| c.id == cs)
339 .ok_or_else(|| decode_err("SOS: unknown component id"))?;
340
341 components.push(ScanComponentSelector {
342 component_index,
343 dc_table_id,
344 ac_table_id,
345 });
346 }
347
348 // Spectral selection and successive approximation (must be 0, 63, 0 for baseline)
349 let _ss = reader.read_byte()?;
350 let _se = reader.read_byte()?;
351 let _ah_al = reader.read_byte()?;
352
353 Ok(ScanHeader { components })
354}
355
356// ---------------------------------------------------------------------------
357// Bit reader for entropy-coded data (MSB-first, with byte stuffing)
358// ---------------------------------------------------------------------------
359
360struct BitReader<'a> {
361 data: &'a [u8],
362 pos: usize,
363 bit_buf: u32,
364 bits_in_buf: u8,
365 /// Set when we encounter a real marker during reading.
366 marker_found: Option<u8>,
367}
368
369impl<'a> BitReader<'a> {
370 fn new(data: &'a [u8], start: usize) -> Self {
371 Self {
372 data,
373 pos: start,
374 bit_buf: 0,
375 bits_in_buf: 0,
376 marker_found: None,
377 }
378 }
379
380 /// Fill the bit buffer with at least `need` bits.
381 fn fill_bits(&mut self, need: u8) -> Result<(), ImageError> {
382 while self.bits_in_buf < need {
383 if self.pos >= self.data.len() {
384 return Err(decode_err("JPEG: unexpected end in entropy data"));
385 }
386 let b = self.data[self.pos];
387 self.pos += 1;
388
389 if b == 0xFF {
390 if self.pos >= self.data.len() {
391 return Err(decode_err("JPEG: unexpected end after 0xFF"));
392 }
393 let next = self.data[self.pos];
394 self.pos += 1;
395 if next == 0x00 {
396 // Byte stuffing: literal 0xFF
397 self.bit_buf = (self.bit_buf << 8) | 0xFF;
398 self.bits_in_buf += 8;
399 } else {
400 // Real marker found
401 self.marker_found = Some(next);
402 // Pad with zeros
403 self.bit_buf <<= 8;
404 self.bits_in_buf += 8;
405 }
406 } else {
407 self.bit_buf = (self.bit_buf << 8) | (b as u32);
408 self.bits_in_buf += 8;
409 }
410 }
411 Ok(())
412 }
413
414 fn read_bits(&mut self, count: u8) -> Result<u16, ImageError> {
415 if count == 0 {
416 return Ok(0);
417 }
418 self.fill_bits(count)?;
419 self.bits_in_buf -= count;
420 let val = (self.bit_buf >> self.bits_in_buf) & ((1 << count) - 1);
421 Ok(val as u16)
422 }
423
424 fn align_to_byte(&mut self) {
425 self.bits_in_buf = 0;
426 self.bit_buf = 0;
427 }
428
429 fn position(&self) -> usize {
430 self.pos
431 }
432}
433
434// ---------------------------------------------------------------------------
435// Huffman decoding
436// ---------------------------------------------------------------------------
437
438fn huff_decode(reader: &mut BitReader, table: &HuffTable) -> Result<u8, ImageError> {
439 let mut code = 0i32;
440 for i in 0..16 {
441 code = (code << 1) | reader.read_bits(1)? as i32;
442 if table.counts[i] > 0 && code <= table.max_code[i] {
443 let idx = (table.val_offset[i] + code) as usize;
444 if idx >= table.symbols.len() {
445 return Err(decode_err("Huffman: symbol index out of range"));
446 }
447 return Ok(table.symbols[idx]);
448 }
449 }
450 Err(decode_err("Huffman: invalid code"))
451}
452
/// Convert a `bits`-wide raw magnitude field into its signed value (the EXTEND
/// procedure of T.81).
///
/// A field whose top bit is set is positive and used as-is. Otherwise it encodes the
/// negative number `value - (2^bits - 1)`. A zero-width field is the value 0.
fn extend(value: u16, bits: u8) -> i32 {
    if bits == 0 {
        return 0;
    }
    let raw = i32::from(value);
    let threshold = 1i32 << (bits - 1);
    if raw < threshold {
        raw - ((1i32 << bits) - 1)
    } else {
        raw
    }
}
466
467fn decode_dc(reader: &mut BitReader, table: &HuffTable) -> Result<i32, ImageError> {
468 let category = huff_decode(reader, table)?;
469 if category == 0 {
470 return Ok(0);
471 }
472 if category > 15 {
473 return Err(decode_err("DC: invalid category"));
474 }
475 let bits = reader.read_bits(category)?;
476 Ok(extend(bits, category))
477}
478
479fn decode_ac(
480 reader: &mut BitReader,
481 table: &HuffTable,
482 block: &mut [i32; 64],
483) -> Result<(), ImageError> {
484 let mut k = 1usize;
485 while k < 64 {
486 let rs = huff_decode(reader, table)?;
487 let run = (rs >> 4) as usize;
488 let category = rs & 0x0F;
489
490 if category == 0 {
491 if run == 0 {
492 // EOB: fill rest with zeros
493 while k < 64 {
494 block[k] = 0;
495 k += 1;
496 }
497 return Ok(());
498 } else if run == 0x0F {
499 // ZRL: 16 zeros
500 for _ in 0..16 {
501 if k < 64 {
502 block[k] = 0;
503 k += 1;
504 }
505 }
506 continue;
507 } else {
508 return Err(decode_err("AC: invalid run/category combination"));
509 }
510 }
511
512 // Skip `run` zeros
513 for _ in 0..run {
514 if k < 64 {
515 block[k] = 0;
516 k += 1;
517 }
518 }
519
520 if k >= 64 {
521 return Err(decode_err("AC: coefficient index out of range"));
522 }
523
524 let bits = reader.read_bits(category)?;
525 block[k] = extend(bits, category);
526 k += 1;
527 }
528 Ok(())
529}
530
531// ---------------------------------------------------------------------------
532// Dequantization
533// ---------------------------------------------------------------------------
534
535fn dequantize(block: &mut [i32; 64], quant: &QuantTable) {
536 for (b, &q) in block.iter_mut().zip(quant.values.iter()) {
537 *b *= q as i32;
538 }
539}
540
541// ---------------------------------------------------------------------------
542// Inverse DCT
543// ---------------------------------------------------------------------------
544
545/// Reorder from zigzag to natural 8x8 row-major order.
546fn unzigzag(zigzag: &[i32; 64]) -> [i32; 64] {
547 let mut natural = [0i32; 64];
548 for i in 0..64 {
549 natural[ZIGZAG[i]] = zigzag[i];
550 }
551 natural
552}
553
// Integer inverse DCT: the "islow" algorithm of the IJG libjpeg (`jidctint.c`), a
// fixed-point form of the Loeffler/Ligtenberg/Moschytz factorisation. Each `FIX_a_bcd`
// is the real constant a.bcd... scaled by 2^CONST_BITS (8192) and rounded.
const FIX_0_298: i32 = 2446; // 0.298631336
const FIX_0_390: i32 = 3196; // 0.390180644
const FIX_0_541: i32 = 4433; // 0.541196100
const FIX_0_765: i32 = 6270; // 0.765366865
const FIX_0_899: i32 = 7373; // 0.899976223
const FIX_1_175: i32 = 9633; // 1.175875602
const FIX_1_501: i32 = 12299; // 1.501321110
const FIX_1_847: i32 = 15137; // 1.847759065
const FIX_1_961: i32 = 16069; // 1.961570560
const FIX_2_053: i32 = 16819; // 2.053119869
const FIX_2_562: i32 = 20995; // 2.562915447
const FIX_3_072: i32 = 25172; // 3.072711026

/// Fractional bits of the `FIX_*` constants.
const CONST_BITS: i32 = 13;
/// Extra fractional bits kept between the column pass and the row pass.
const PASS1_BITS: i32 = 2;

/// Right shift by `shift` bits with round-half-up.
fn descale(value: i64, shift: i32) -> i64 {
    (value + (1i64 << (shift - 1))) >> shift
}

/// One 1-D IDCT over eight inputs given in frequency order.
///
/// Returns the eight spatial outputs in order, still scaled by 2^CONST_BITS. The math is
/// done in `i64`, so no `i32`-sized input can overflow the intermediate sums.
fn idct_1d(s: [i64; 8]) -> [i64; 8] {
    let fix = |c: i32| i64::from(c);

    // Even part: inputs 0, 2, 4, 6.
    let z1 = (s[2] + s[6]) * fix(FIX_0_541);
    let tmp2 = z1 - s[6] * fix(FIX_1_847);
    let tmp3 = z1 + s[2] * fix(FIX_0_765);
    let tmp0 = (s[0] + s[4]) << CONST_BITS;
    let tmp1 = (s[0] - s[4]) << CONST_BITS;
    let tmp10 = tmp0 + tmp3;
    let tmp13 = tmp0 - tmp3;
    let tmp11 = tmp1 + tmp2;
    let tmp12 = tmp1 - tmp2;

    // Odd part: inputs 7, 5, 3, 1.
    let (t0, t1, t2, t3) = (s[7], s[5], s[3], s[1]);
    let z5 = (t0 + t1 + t2 + t3) * fix(FIX_1_175);
    let z1 = (t0 + t3) * -fix(FIX_0_899);
    let z2 = (t1 + t2) * -fix(FIX_2_562);
    let z3 = (t0 + t2) * -fix(FIX_1_961) + z5;
    let z4 = (t1 + t3) * -fix(FIX_0_390) + z5;
    let o0 = t0 * fix(FIX_0_298) + z1 + z3;
    let o1 = t1 * fix(FIX_2_053) + z2 + z4;
    let o2 = t2 * fix(FIX_3_072) + z2 + z3;
    let o3 = t3 * fix(FIX_1_501) + z1 + z4;

    [
        tmp10 + o3,
        tmp11 + o2,
        tmp12 + o1,
        tmp13 + o0,
        tmp13 - o0,
        tmp12 - o1,
        tmp11 - o2,
        tmp10 - o3,
    ]
}

/// Undo the JPEG level shift (+128) and clamp to the 8-bit sample range.
fn level_shift_and_clamp(value: i64) -> u8 {
    let narrowed = (value + 128).clamp(i64::from(i32::MIN), i64::from(i32::MAX));
    clamp_to_u8(narrowed as i32)
}

/// Perform the 2D IDCT and produce 64 pixel values, level-shifted by +128 and clamped
/// to 0..=255.
///
/// Input is in natural 8x8 row-major order (row = vertical frequency) and already
/// dequantized. Any `i32` input is accepted without overflow.
fn idct_block(coeffs: &[i32; 64]) -> [u8; 64] {
    let mut workspace = [0i64; 64];

    // Pass 1: columns. Results keep PASS1_BITS extra fractional bits.
    for col in 0..8 {
        let column: [i64; 8] = std::array::from_fn(|row| i64::from(coeffs[row * 8 + col]));
        let outputs = if column[1..].iter().all(|&c| c == 0) {
            // DC-only column: the output is flat, so skip the multiplies.
            [column[0] << PASS1_BITS; 8]
        } else {
            idct_1d(column).map(|v| descale(v, CONST_BITS - PASS1_BITS))
        };
        for (row, &value) in outputs.iter().enumerate() {
            workspace[row * 8 + col] = value;
        }
    }

    // Pass 2: rows. Remove the remaining fixed-point scale and the 1/8 factor of the
    // 2-D transform, undo the level shift, and clamp.
    let mut output = [0u8; 64];
    for (src, dst) in workspace.chunks_exact(8).zip(output.chunks_exact_mut(8)) {
        let row: [i64; 8] = std::array::from_fn(|i| src[i]);
        if row[1..].iter().all(|&c| c == 0) {
            dst.fill(level_shift_and_clamp(descale(row[0], PASS1_BITS + 3)));
        } else {
            let values = idct_1d(row);
            for (d, &v) in dst.iter_mut().zip(values.iter()) {
                *d = level_shift_and_clamp(descale(v, CONST_BITS + PASS1_BITS + 3));
            }
        }
    }
    output
}

/// Clamp an integer to the 8-bit sample range.
fn clamp_to_u8(val: i32) -> u8 {
    val.clamp(0, 255) as u8
}
742
743// ---------------------------------------------------------------------------
744// MCU-based scan decoding
745// ---------------------------------------------------------------------------
746
747/// Decode all MCUs from entropy-coded data.
748/// Returns per-component sample buffers (at component resolution).
749fn decode_scan(
750 reader: &mut BitReader,
751 frame: &FrameHeader,
752 scan: &ScanHeader,
753 quant_tables: &[Option<QuantTable>; 4],
754 dc_tables: &[Option<HuffTable>; 4],
755 ac_tables: &[Option<HuffTable>; 4],
756 restart_interval: u16,
757) -> Result<Vec<Vec<u8>>, ImageError> {
758 let h_max = frame.h_max as usize;
759 let v_max = frame.v_max as usize;
760
761 // MCU dimensions in pixels
762 let mcu_w = h_max * 8;
763 let mcu_h = v_max * 8;
764
765 // Number of MCUs in each direction
766 let mcus_x = (frame.width as usize).div_ceil(mcu_w);
767 let mcus_y = (frame.height as usize).div_ceil(mcu_h);
768
769 // Allocate per-component sample buffers
770 let mut comp_buffers: Vec<Vec<u8>> = Vec::with_capacity(frame.components.len());
771 let mut comp_widths: Vec<usize> = Vec::with_capacity(frame.components.len());
772 let mut comp_heights: Vec<usize> = Vec::with_capacity(frame.components.len());
773
774 for comp in &frame.components {
775 let cw = mcus_x * comp.h_sample as usize * 8;
776 let ch = mcus_y * comp.v_sample as usize * 8;
777 comp_buffers.push(vec![0u8; cw * ch]);
778 comp_widths.push(cw);
779 comp_heights.push(ch);
780 }
781
782 // DC predictors per component
783 let mut dc_pred = vec![0i32; frame.components.len()];
784
785 let mut mcu_count = 0u32;
786 let restart_interval = restart_interval as u32;
787
788 for mcu_y in 0..mcus_y {
789 for mcu_x in 0..mcus_x {
790 // Handle restart marker
791 if restart_interval > 0 && mcu_count > 0 && mcu_count.is_multiple_of(restart_interval) {
792 reader.align_to_byte();
793 // Skip past the restart marker
794 // The marker may already have been found by the bit reader
795 if reader.marker_found.is_some() {
796 reader.marker_found = None;
797 }
798 // Reset DC predictors
799 dc_pred.fill(0);
800 }
801
802 for scan_comp in &scan.components {
803 let ci = scan_comp.component_index;
804 let comp = &frame.components[ci];
805
806 let dc_table = dc_tables[scan_comp.dc_table_id as usize]
807 .as_ref()
808 .ok_or_else(|| decode_err("missing DC Huffman table"))?;
809 let ac_table = ac_tables[scan_comp.ac_table_id as usize]
810 .as_ref()
811 .ok_or_else(|| decode_err("missing AC Huffman table"))?;
812 let quant = quant_tables[comp.quant_table_id as usize]
813 .as_ref()
814 .ok_or_else(|| decode_err("missing quantization table"))?;
815
816 // Each component contributes h_sample * v_sample blocks per MCU
817 for bv in 0..comp.v_sample as usize {
818 for bh in 0..comp.h_sample as usize {
819 // Decode one 8x8 block
820 let mut block = [0i32; 64];
821
822 // DC
823 let dc_diff = decode_dc(reader, dc_table)?;
824 dc_pred[ci] += dc_diff;
825 block[0] = dc_pred[ci];
826
827 // AC
828 decode_ac(reader, ac_table, &mut block)?;
829
830 // Dequantize (in zigzag order)
831 dequantize(&mut block, quant);
832
833 // Unzigzag to natural order
834 let natural = unzigzag(&block);
835
836 // IDCT
837 let pixels = idct_block(&natural);
838
839 // Write block to component buffer
840 let block_x = mcu_x * comp.h_sample as usize * 8 + bh * 8;
841 let block_y = mcu_y * comp.v_sample as usize * 8 + bv * 8;
842 let cw = comp_widths[ci];
843
844 for row in 0..8 {
845 let dst_y = block_y + row;
846 if dst_y < comp_heights[ci] {
847 let dst_offset = dst_y * cw + block_x;
848 for col in 0..8 {
849 if block_x + col < cw {
850 comp_buffers[ci][dst_offset + col] = pixels[row * 8 + col];
851 }
852 }
853 }
854 }
855 }
856 }
857 }
858 mcu_count += 1;
859 }
860 }
861
862 Ok(comp_buffers)
863}
864
865// ---------------------------------------------------------------------------
866// Chroma upsampling
867// ---------------------------------------------------------------------------
868
/// Resample one component plane to `target_width` x `target_height` using nearest
/// neighbour.
///
/// `samples` is a row-major `sample_width` x `sample_height` plane. Its component has
/// sampling factors `h_sample` x `v_sample`, out of the frame maxima `h_max` x `v_max`.
/// Output pixel (x, y) takes sample `(x * h_sample / h_max, y * v_sample / v_max)`,
/// clamped to the plane. This is exact for integer ratios (4:2:0, 4:2:2) and also
/// handles non-integer ones such as 3:2. A plane already at full resolution is just
/// cropped.
// Nine scalar parameters mirror the call sites; bundling them would not aid clarity.
#[allow(clippy::too_many_arguments)]
fn upsample(
    samples: &[u8],
    sample_width: usize,
    sample_height: usize,
    h_sample: u8,
    v_sample: u8,
    h_max: u8,
    v_max: u8,
    target_width: usize,
    target_height: usize,
) -> Vec<u8> {
    let mut out = Vec::with_capacity(target_width * target_height);
    if target_width == 0 || target_height == 0 {
        return out;
    }

    let full_resolution = h_sample == h_max && v_sample == v_max;
    if full_resolution && sample_width >= target_width && sample_height >= target_height {
        // Crop away the MCU padding.
        for y in 0..target_height {
            let start = y * sample_width;
            out.extend_from_slice(&samples[start..start + target_width]);
        }
        return out;
    }

    if sample_width == 0 || sample_height == 0 {
        // Nothing to sample from: return a zero-filled plane.
        out.resize(target_width * target_height, 0);
        return out;
    }

    let (hs, hm) = (usize::from(h_sample), usize::from(h_max));
    let (vs, vm) = (usize::from(v_sample), usize::from(v_max));
    // The column mapping is the same for every row, so compute it once.
    let src_cols: Vec<usize> = (0..target_width)
        .map(|x| (x * hs / hm).min(sample_width - 1))
        .collect();
    for y in 0..target_height {
        let sy = (y * vs / vm).min(sample_height - 1);
        let row = &samples[sy * sample_width..(sy + 1) * sample_width];
        out.extend(src_cols.iter().map(|&sx| row[sx]));
    }
    out
}
908
909// ---------------------------------------------------------------------------
910// YCbCr to RGB conversion
911// ---------------------------------------------------------------------------
912
/// Convert full-resolution Y, Cb, Cr planes to interleaved RGBA8 with opaque alpha.
///
/// Each plane holds `width * height` samples in row-major order. The conversion uses
/// the BT.601 equations in 8.8 fixed point, where Cb' = Cb - 128 and Cr' = Cr - 128:
/// - R = Y + 1.402 Cr'
/// - G = Y - 0.344 Cb' - 0.714 Cr'
/// - B = Y + 1.772 Cb'
///
/// Each channel is rounded and clamped to 0..=255.
fn ycbcr_to_rgba(
    y_samples: &[u8],
    cb_samples: &[u8],
    cr_samples: &[u8],
    width: usize,
    height: usize,
) -> Vec<u8> {
    let pixel_count = width * height;
    let to_channel = |v: i32| v.clamp(0, 255) as u8;

    let mut rgba = Vec::with_capacity(pixel_count * 4);
    let triples = y_samples[..pixel_count]
        .iter()
        .zip(&cb_samples[..pixel_count])
        .zip(&cr_samples[..pixel_count]);
    for ((&luma, &cb), &cr) in triples {
        let luma = i32::from(luma);
        let cb = i32::from(cb) - 128;
        let cr = i32::from(cr) - 128;

        let red = luma + ((cr * 359 + 128) >> 8);
        let green = luma - ((cb * 88 + cr * 183 + 128) >> 8);
        let blue = luma + ((cb * 454 + 128) >> 8);

        rgba.extend_from_slice(&[to_channel(red), to_channel(green), to_channel(blue), 255]);
    }
    rgba
}
938
939// ---------------------------------------------------------------------------
940// Public API
941// ---------------------------------------------------------------------------
942
943/// Decode a JPEG image into an RGBA8 `Image`.
944///
945/// Supports baseline DCT (SOF0) with Huffman coding, 4:4:4/4:2:2/4:2:0
946/// chroma subsampling, and restart markers.
947pub fn decode_jpeg(data: &[u8]) -> Result<Image, ImageError> {
948 let mut reader = JpegReader::new(data);
949
950 // Verify SOI
951 if reader.remaining() < 2 {
952 return Err(decode_err("JPEG: too short"));
953 }
954 let soi1 = reader.read_byte()?;
955 let soi2 = reader.read_byte()?;
956 if soi1 != 0xFF || soi2 != MARKER_SOI {
957 return Err(decode_err("JPEG: missing SOI marker"));
958 }
959
960 let mut quant_tables: [Option<QuantTable>; 4] = [None, None, None, None];
961 let mut dc_tables: [Option<HuffTable>; 4] = [None, None, None, None];
962 let mut ac_tables: [Option<HuffTable>; 4] = [None, None, None, None];
963 let mut frame: Option<FrameHeader> = None;
964 let mut restart_interval: u16 = 0;
965 let mut comp_buffers: Option<Vec<Vec<u8>>> = None;
966
967 loop {
968 // Find next marker
969 let mut b = reader.read_byte()?;
970 if b != 0xFF {
971 // Sometimes there is padding; scan for 0xFF
972 while b != 0xFF {
973 if reader.remaining() == 0 {
974 return Err(decode_err("JPEG: unexpected end searching for marker"));
975 }
976 b = reader.read_byte()?;
977 }
978 }
979 // Skip fill bytes (multiple 0xFF)
980 let mut marker = reader.read_byte()?;
981 while marker == 0xFF {
982 marker = reader.read_byte()?;
983 }
984 if marker == 0x00 {
985 continue; // Stuffed byte outside scan, ignore
986 }
987
988 match marker {
989 MARKER_EOI => break,
990
991 MARKER_SOF0 => {
992 frame = Some(parse_sof0(&mut reader)?);
993 }
994
995 // Reject progressive/lossless/etc
996 0xC1..=0xC3 | 0xC5..=0xC7 | 0xC9..=0xCB | 0xCD..=0xCF => {
997 return Err(decode_err("JPEG: only baseline DCT (SOF0) is supported"));
998 }
999
1000 MARKER_DHT => {
1001 parse_dht(&mut reader, &mut dc_tables, &mut ac_tables)?;
1002 }
1003
1004 MARKER_DQT => {
1005 parse_dqt(&mut reader, &mut quant_tables)?;
1006 }
1007
1008 MARKER_DRI => {
1009 let _len = reader.read_u16_be()?;
1010 restart_interval = reader.read_u16_be()?;
1011 }
1012
1013 MARKER_SOS => {
1014 let f = frame
1015 .as_ref()
1016 .ok_or_else(|| decode_err("JPEG: SOS before SOF"))?;
1017 let scan = parse_sos(&mut reader, f)?;
1018
1019 let mut bit_reader = BitReader::new(reader.data, reader.pos);
1020 comp_buffers = Some(decode_scan(
1021 &mut bit_reader,
1022 f,
1023 &scan,
1024 &quant_tables,
1025 &dc_tables,
1026 &ac_tables,
1027 restart_interval,
1028 )?);
1029 reader.pos = bit_reader.position();
1030 // If the bit reader found a marker, back up so the outer loop sees it
1031 if let Some(m) = bit_reader.marker_found {
1032 if m == MARKER_EOI {
1033 break;
1034 }
1035 // Back up to the 0xFF before the marker
1036 reader.pos -= 1;
1037 // We need to also account for the 0xFF
1038 if reader.pos > 0 {
1039 reader.pos -= 1;
1040 }
1041 }
1042 }
1043
1044 // APP0..APP15, COM: skip
1045 _ => {
1046 if reader.remaining() >= 2 {
1047 let len = reader.read_u16_be()? as usize;
1048 if len >= 2 {
1049 reader.skip(len - 2)?;
1050 }
1051 }
1052 }
1053 }
1054 }
1055
1056 let f = frame.ok_or_else(|| decode_err("JPEG: no SOF0 frame found"))?;
1057 let buffers = comp_buffers.ok_or_else(|| decode_err("JPEG: no scan data"))?;
1058
1059 let width = f.width as usize;
1060 let height = f.height as usize;
1061
1062 if f.components.len() == 1 {
1063 // Grayscale
1064 let h_max = f.h_max;
1065 let v_max = f.v_max;
1066 let comp = &f.components[0];
1067 let mcus_x = width.div_ceil(h_max as usize * 8);
1068 let mcus_y = height.div_ceil(v_max as usize * 8);
1069 let cw = mcus_x * comp.h_sample as usize * 8;
1070 let ch = mcus_y * comp.v_sample as usize * 8;
1071
1072 let gray = upsample(
1073 &buffers[0],
1074 cw,
1075 ch,
1076 comp.h_sample,
1077 comp.v_sample,
1078 h_max,
1079 v_max,
1080 width,
1081 height,
1082 );
1083 crate::pixel::from_grayscale(width as u32, height as u32, &gray)
1084 } else if f.components.len() == 3 {
1085 // YCbCr
1086 let h_max = f.h_max;
1087 let v_max = f.v_max;
1088
1089 let mut upsampled = Vec::with_capacity(3);
1090 for (ci, comp) in f.components.iter().enumerate() {
1091 let mcus_x = width.div_ceil(h_max as usize * 8);
1092 let mcus_y = height.div_ceil(v_max as usize * 8);
1093 let cw = mcus_x * comp.h_sample as usize * 8;
1094 let ch = mcus_y * comp.v_sample as usize * 8;
1095
1096 upsampled.push(upsample(
1097 &buffers[ci],
1098 cw,
1099 ch,
1100 comp.h_sample,
1101 comp.v_sample,
1102 h_max,
1103 v_max,
1104 width,
1105 height,
1106 ));
1107 }
1108
1109 let rgba = ycbcr_to_rgba(&upsampled[0], &upsampled[1], &upsampled[2], width, height);
1110 Image::new(width as u32, height as u32, rgba)
1111 } else {
1112 Err(decode_err("JPEG: unsupported number of components"))
1113 }
1114}
1115
1116// ---------------------------------------------------------------------------
1117// Tests
1118// ---------------------------------------------------------------------------
1119
#[cfg(test)]
mod tests {
    use super::*;

    // -- Zigzag table --

    #[test]
    fn zigzag_covers_all_indices() {
        let mut seen = [false; 64];
        for &idx in &ZIGZAG {
            assert!(idx < 64);
            seen[idx] = true;
        }
        assert!(seen.iter().all(|&s| s), "ZIGZAG must cover 0..63");
    }

    #[test]
    fn zigzag_known_positions() {
        assert_eq!(ZIGZAG[0], 0);
        assert_eq!(ZIGZAG[1], 1);
        assert_eq!(ZIGZAG[2], 8);
        assert_eq!(ZIGZAG[3], 16);
        assert_eq!(ZIGZAG[63], 63);
    }

    // -- Unzigzag --

    #[test]
    fn unzigzag_dc_only() {
        let mut zigzag = [0i32; 64];
        zigzag[0] = 100;
        let natural = unzigzag(&zigzag);
        assert_eq!(natural[0], 100);
        for (i, &coeff) in natural.iter().enumerate().skip(1) {
            assert_eq!(coeff, 0, "coefficient {i} should be zero");
        }
    }

    // -- Extend (sign extension) --

    #[test]
    fn extend_values() {
        // Category 1: values 0,1 -> -1, 1
        assert_eq!(extend(0, 1), -1);
        assert_eq!(extend(1, 1), 1);
        // Category 2: values 0,1,2,3 -> -3,-2,2,3
        assert_eq!(extend(0, 2), -3);
        assert_eq!(extend(1, 2), -2);
        assert_eq!(extend(2, 2), 2);
        assert_eq!(extend(3, 2), 3);
        // Category 0
        assert_eq!(extend(0, 0), 0);
    }

    // -- IDCT: DC only block --

    #[test]
    fn idct_dc_only() {
        let mut coeffs = [0i32; 64];
        coeffs[0] = 128; // DC coefficient
        let pixels = idct_block(&coeffs);
        // The 2-D IDCT scales a lone DC term by C(0)^2 / 4 = 1/8 in total
        // (1/sqrt(8) per dimension), so every pixel should be
        // DC/8 + 128 (level shift) = 128/8 + 128 = 144.
        let expected = 128 + (128 / 8);
        for &p in &pixels {
            // Allow +-1 for rounding
            assert!(
                (p as i32 - expected).unsigned_abs() <= 1,
                "expected ~{expected}, got {p}"
            );
        }
    }

    #[test]
    fn idct_zero_block() {
        let coeffs = [0i32; 64];
        let pixels = idct_block(&coeffs);
        // All zeros + level shift of 128
        for &p in &pixels {
            assert_eq!(p, 128);
        }
    }

    // -- YCbCr to RGB --

    #[test]
    fn ycbcr_white() {
        let rgba = ycbcr_to_rgba(&[255], &[128], &[128], 1, 1);
        assert_eq!(rgba[0], 255); // R
        assert_eq!(rgba[1], 255); // G
        assert_eq!(rgba[2], 255); // B
        assert_eq!(rgba[3], 255); // A
    }

    #[test]
    fn ycbcr_black() {
        let rgba = ycbcr_to_rgba(&[0], &[128], &[128], 1, 1);
        assert_eq!(rgba[0], 0); // R
        assert_eq!(rgba[1], 0); // G
        assert_eq!(rgba[2], 0); // B
        assert_eq!(rgba[3], 255); // A
    }

    #[test]
    fn ycbcr_gray_128() {
        let rgba = ycbcr_to_rgba(&[128], &[128], &[128], 1, 1);
        assert_eq!(rgba[0], 128); // R
        assert_eq!(rgba[1], 128); // G
        assert_eq!(rgba[2], 128); // B
    }

    // -- Huffman table build and decode --

    #[test]
    fn huffman_build_and_decode() {
        // Simple Huffman table: 2 symbols
        // Symbol A has code 0 (1 bit), symbol B has code 1 (1 bit)
        let mut counts = [0u8; 16];
        counts[0] = 2; // 2 codes of length 1
        let symbols = vec![b'A', b'B'];
        let table = HuffTable::build(counts, symbols);

        // Encode "ABA" as bits: 0, 1, 0 = 0b010_00000 = 0x40
        let data = [0x40u8];
        let mut reader = BitReader::new(&data, 0);
        assert_eq!(huff_decode(&mut reader, &table).unwrap(), b'A');
        assert_eq!(huff_decode(&mut reader, &table).unwrap(), b'B');
        assert_eq!(huff_decode(&mut reader, &table).unwrap(), b'A');
    }

    #[test]
    fn huffman_multi_length() {
        // 3 symbols: A=0 (1 bit), B=10 (2 bits), C=11 (2 bits)
        let mut counts = [0u8; 16];
        counts[0] = 1; // 1 code of length 1
        counts[1] = 2; // 2 codes of length 2
        let symbols = vec![b'A', b'B', b'C'];
        let table = HuffTable::build(counts, symbols);

        // "ABCA" = 0, 10, 11, 0 = 0b_0_10_11_0_00 = 0x58 (MSB first)
        let data = [0x58u8];
        let mut reader = BitReader::new(&data, 0);
        assert_eq!(huff_decode(&mut reader, &table).unwrap(), b'A');
        assert_eq!(huff_decode(&mut reader, &table).unwrap(), b'B');
        assert_eq!(huff_decode(&mut reader, &table).unwrap(), b'C');
        assert_eq!(huff_decode(&mut reader, &table).unwrap(), b'A');
    }

    // -- Bit reader byte stuffing --

    #[test]
    fn bit_reader_stuffing() {
        // 0xFF 0x00 should produce a literal 0xFF byte
        let data = [0xFF, 0x00, 0x80];
        let mut reader = BitReader::new(&data, 0);
        // Read 8 bits -> should get 0xFF
        let val = reader.read_bits(8).unwrap();
        assert_eq!(val, 0xFF);
        // Read 8 more bits -> 0x80
        let val = reader.read_bits(8).unwrap();
        assert_eq!(val, 0x80);
    }

    // -- Dequantize --

    #[test]
    fn dequantize_basic() {
        let mut block = [0i32; 64];
        block[0] = 10;
        block[1] = -5;
        let quant = QuantTable {
            values: {
                let mut v = [1u16; 64];
                v[0] = 16;
                v[1] = 11;
                v
            },
        };
        dequantize(&mut block, &quant);
        assert_eq!(block[0], 160);
        assert_eq!(block[1], -55);
    }

    // -- Upsampling --
    //
    // Argument order: (samples, comp_w, comp_h, h_sample, v_sample,
    //                  h_max, v_max, out_w, out_h).

    #[test]
    fn upsample_identity() {
        let samples = vec![1, 2, 3, 4];
        let result = upsample(&samples, 2, 2, 2, 2, 2, 2, 2, 2);
        assert_eq!(result, samples);
    }

    #[test]
    fn upsample_2x_horizontal() {
        // 2x1 row (h_sample=1, h_max=2) upsampled horizontally to 4x1
        let samples = vec![10, 20];
        let result = upsample(&samples, 2, 1, 1, 1, 2, 1, 4, 1);
        assert_eq!(result, vec![10, 10, 20, 20]);
    }

    #[test]
    fn upsample_2x_both() {
        // 1x1 upsampled to 2x2
        let samples = vec![42];
        let result = upsample(&samples, 1, 1, 1, 1, 2, 2, 2, 2);
        assert_eq!(result, vec![42, 42, 42, 42]);
    }

    // -- Minimal JPEG construction and decoding --

    /// Returns the canonical Huffman code for `symbol` as `(code, bit_length)`,
    /// derived from DHT-style `counts`/`symbols` per ITU-T T.81 Annex C: codes of
    /// one length are consecutive, and moving to the next length appends a 0 bit.
    ///
    /// Returns `None` if `symbol` does not appear in the table.
    fn huffman_code_for(counts: &[u8; 16], symbols: &[u8], symbol: u8) -> Option<(u32, u8)> {
        let mut code = 0u32;
        let mut next_symbol = symbols.iter();
        for (len_index, &count) in counts.iter().enumerate() {
            for _ in 0..count {
                if *next_symbol.next()? == symbol {
                    return Some((code, len_index as u8 + 1));
                }
                code += 1;
            }
            code <<= 1;
        }
        None
    }

    /// Appends a complete DHT segment (marker, length, class/id byte, the 16
    /// per-length code counts, then the symbols) to `out`.
    fn push_dht(out: &mut Vec<u8>, class_and_id: u8, counts: &[u8; 16], symbols: &[u8]) {
        assert_eq!(
            counts.iter().map(|&c| usize::from(c)).sum::<usize>(),
            symbols.len(),
            "DHT counts must add up to the number of symbols"
        );
        let len = u16::try_from(2 + 1 + 16 + symbols.len()).expect("DHT segment too long");
        out.extend_from_slice(&[0xFF, MARKER_DHT]);
        out.extend_from_slice(&len.to_be_bytes());
        out.push(class_and_id);
        out.extend_from_slice(counts);
        out.extend_from_slice(symbols);
    }

    /// Build a minimal baseline JPEG: one 8x8 grayscale component, every
    /// quantization value 1, and a single block whose only non-zero coefficient
    /// is the DC term `dc_value`.
    ///
    /// # Panics
    ///
    /// Panics if `|dc_value| > 255`: such values need DC categories above 8,
    /// which the embedded DC Huffman table does not define.
    fn build_minimal_grayscale_jpeg(dc_value: i32) -> Vec<u8> {
        assert!(
            (-255..=255).contains(&dc_value),
            "dc_value {dc_value} needs a DC category > 8, which the test table lacks"
        );

        let mut out = Vec::new();

        // SOI
        out.extend_from_slice(&[0xFF, MARKER_SOI]);

        // DQT: one 8-bit table, id 0, all values = 1 (dequantization is a no-op).
        out.extend_from_slice(&[0xFF, MARKER_DQT]);
        let dqt_len: u16 = 2 + 1 + 64;
        out.extend_from_slice(&dqt_len.to_be_bytes());
        out.push(0x00); // precision=0 (8-bit), table_id=0
        out.extend_from_slice(&[1u8; 64]);

        // SOF0: 8-bit precision, 8x8, one component.
        out.extend_from_slice(&[0xFF, MARKER_SOF0]);
        let sof_len: u16 = 2 + 1 + 2 + 2 + 1 + 3;
        out.extend_from_slice(&sof_len.to_be_bytes());
        out.push(8); // precision
        out.extend_from_slice(&8u16.to_be_bytes()); // height=8
        out.extend_from_slice(&8u16.to_be_bytes()); // width=8
        out.push(1); // 1 component
        out.push(1); // component id=1
        out.push(0x11); // H=1, V=1
        out.push(0); // quant table 0

        // DHT: DC table 0, the standard luminance DC table (T.81 Table K.3)
        // truncated to categories 0..=8. Its canonical codes are:
        //   cat 0: 00     cat 3: 100    cat 6: 1110
        //   cat 1: 010    cat 4: 101    cat 7: 11110
        //   cat 2: 011    cat 5: 110    cat 8: 111110
        let dc_counts: [u8; 16] = [0, 1, 5, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let dc_symbols: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7, 8];
        push_dht(&mut out, 0x00, &dc_counts, dc_symbols); // class 0 (DC), id 0

        // DHT: AC table 0 holding only EOB (symbol 0x00) with the 1-bit code `0`.
        // A DC-only block needs nothing else.
        let ac_counts: [u8; 16] = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let ac_symbols: &[u8] = &[0x00];
        push_dht(&mut out, 0x10, &ac_counts, ac_symbols); // class 1 (AC), id 0

        // SOS: one component (selector 1, DC table 0, AC table 0),
        // Ss=0, Se=63, Ah/Al=0.
        out.extend_from_slice(&[0xFF, MARKER_SOS]);
        let sos_len: u16 = 2 + 1 + 2 + 3;
        out.extend_from_slice(&sos_len.to_be_bytes());
        out.extend_from_slice(&[1, 1, 0x00, 0, 63, 0]);

        // Entropy-coded data: DC category code, `cat` magnitude bits, AC EOB.
        // Positive values are sent as-is; negative values as (2^cat - 1) - |v|,
        // i.e. the low `cat` bits of v - 1 (T.81 F.1.2.1).
        let abs_val = dc_value.unsigned_abs();
        let cat = (u32::BITS - abs_val.leading_zeros()) as u8; // 0 when dc_value == 0
        let magnitude = if dc_value >= 0 {
            abs_val
        } else {
            ((1u32 << cat) - 1) - abs_val
        };
        let (dc_code, dc_code_len) = huffman_code_for(&dc_counts, dc_symbols, cat)
            .expect("every category allowed by the range check is in the DC table");

        // Pack the fields MSB-first into a u64 (at most 6 + 8 + 1 = 15 bits).
        let mut bits = 0u64;
        let mut bit_len = 0u32;
        for (value, len) in [
            (dc_code, u32::from(dc_code_len)),
            (magnitude, u32::from(cat)),
            (0, 1), // AC EOB: the 1-bit code `0`
        ] {
            if len > 0 {
                bits |= u64::from(value) << (64 - len - bit_len);
                bit_len += len;
            }
        }

        // Emit whole bytes, stuffing 0x00 after any 0xFF. Leftover low bits stay 0.
        // T.81 F.1.2.3 specifies 1-bit padding, but the decoder is not expected to
        // read past the final EOB of this single-block image.
        for i in 0..bit_len.div_ceil(8) {
            let byte = (bits >> (56 - 8 * i)) as u8;
            out.push(byte);
            if byte == 0xFF {
                out.push(0x00);
            }
        }

        // EOI
        out.extend_from_slice(&[0xFF, MARKER_EOI]);

        out
    }

    #[test]
    fn huffman_code_for_matches_canonical_assignment() {
        let counts: [u8; 16] = [0, 1, 5, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let symbols: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(huffman_code_for(&counts, symbols, 0), Some((0b00, 2)));
        assert_eq!(huffman_code_for(&counts, symbols, 1), Some((0b010, 3)));
        assert_eq!(huffman_code_for(&counts, symbols, 5), Some((0b110, 3)));
        assert_eq!(huffman_code_for(&counts, symbols, 6), Some((0b1110, 4)));
        assert_eq!(huffman_code_for(&counts, symbols, 8), Some((0b111110, 6)));
        assert_eq!(huffman_code_for(&counts, symbols, 9), None);
    }

    #[test]
    #[should_panic(expected = "DC category")]
    fn minimal_jpeg_builder_rejects_unsupported_dc() {
        build_minimal_grayscale_jpeg(256);
    }

    #[test]
    fn decode_grayscale_8x8_dc_zero() {
        let jpeg = build_minimal_grayscale_jpeg(0);
        let img = decode_jpeg(&jpeg).unwrap();
        assert_eq!(img.width, 8);
        assert_eq!(img.height, 8);
        // DC=0, quant=1, so coefficient is 0, IDCT of zero block = 128 everywhere
        for i in 0..64 {
            let r = img.data[i * 4];
            let g = img.data[i * 4 + 1];
            let b = img.data[i * 4 + 2];
            let a = img.data[i * 4 + 3];
            assert_eq!(r, 128, "pixel {i}: R should be 128, got {r}");
            assert_eq!(g, 128);
            assert_eq!(b, 128);
            assert_eq!(a, 255);
        }
    }

    #[test]
    fn decode_grayscale_8x8_dc_positive() {
        let jpeg = build_minimal_grayscale_jpeg(64);
        let img = decode_jpeg(&jpeg).unwrap();
        assert_eq!(img.width, 8);
        assert_eq!(img.height, 8);
        // DC=64, IDCT of DC-only -> 64/8 + 128 = 136
        let expected = 128 + 64 / 8;
        for i in 0..64 {
            let r = img.data[i * 4];
            assert!(
                (r as i32 - expected).unsigned_abs() <= 1,
                "pixel {i}: expected ~{expected}, got {r}"
            );
        }
    }

    #[test]
    fn decode_grayscale_8x8_dc_negative() {
        let jpeg = build_minimal_grayscale_jpeg(-64);
        let img = decode_jpeg(&jpeg).unwrap();
        assert_eq!(img.width, 8);
        assert_eq!(img.height, 8);
        // DC=-64, IDCT of DC-only -> -64/8 + 128 = 120
        let expected = 128 - 64 / 8;
        for i in 0..64 {
            let r = img.data[i * 4];
            assert!(
                (r as i32 - expected).unsigned_abs() <= 1,
                "pixel {i}: expected ~{expected}, got {r}"
            );
        }
    }

    // -- Error cases --

    #[test]
    fn error_missing_soi() {
        let data = [0x00, 0x00];
        assert!(decode_jpeg(&data).is_err());
    }

    #[test]
    fn error_too_short() {
        let data = [0xFF];
        assert!(decode_jpeg(&data).is_err());
    }

    #[test]
    fn error_no_frame() {
        // SOI then EOI with no frame
        let data = [0xFF, MARKER_SOI, 0xFF, MARKER_EOI];
        let err = decode_jpeg(&data).unwrap_err();
        assert!(matches!(err, ImageError::Decode(_)));
    }

    #[test]
    fn error_progressive_rejected() {
        let mut data = vec![0xFF, MARKER_SOI, 0xFF, 0xC2]; // SOF2 = progressive
        // Minimal SOF2 header
        data.extend_from_slice(&[0, 11, 8, 0, 8, 0, 8, 1, 1, 0x11, 0]);
        data.push(0xFF);
        data.push(MARKER_EOI);
        let err = decode_jpeg(&data).unwrap_err();
        match err {
            ImageError::Decode(msg) => assert!(msg.contains("baseline"), "got: {msg}"),
            _ => panic!("expected Decode error"),
        }
    }

    // -- Quantization table parsing --

    #[test]
    fn parse_dqt_8bit() {
        let mut data = vec![0, 67]; // length = 67
        data.push(0x00); // precision=0, table_id=0
        for i in 0..64u8 {
            data.push(i + 1);
        }
        let mut reader = JpegReader::new(&data);
        let mut tables: [Option<QuantTable>; 4] = [None, None, None, None];
        parse_dqt(&mut reader, &mut tables).unwrap();
        assert!(tables[0].is_some());
        assert_eq!(tables[0].as_ref().unwrap().values[0], 1);
        assert_eq!(tables[0].as_ref().unwrap().values[63], 64);
    }

    // -- Huffman table parsing --

    #[test]
    fn parse_dht_roundtrip() {
        let mut data = Vec::new();
        let counts: [u8; 16] = [0, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let symbols: &[u8] = &[0, 1, 2, 3, 4, 5];
        let total: usize = counts.iter().map(|&c| c as usize).sum();
        let len = 2 + 1 + 16 + total;
        data.push((len >> 8) as u8);
        data.push(len as u8);
        data.push(0x00); // DC, id 0
        data.extend_from_slice(&counts);
        data.extend_from_slice(symbols);

        let mut reader = JpegReader::new(&data);
        let mut dc_tables: [Option<HuffTable>; 4] = [None, None, None, None];
        let mut ac_tables: [Option<HuffTable>; 4] = [None, None, None, None];
        parse_dht(&mut reader, &mut dc_tables, &mut ac_tables).unwrap();
        assert!(dc_tables[0].is_some());
        assert_eq!(dc_tables[0].as_ref().unwrap().symbols, symbols);
    }

    // -- Frame header parsing --

    #[test]
    fn parse_sof0_basic() {
        let mut data = Vec::new();
        let len: u16 = 8 + 3;
        data.push((len >> 8) as u8);
        data.push(len as u8);
        data.push(8); // precision
        data.push(0);
        data.push(16); // height=16
        data.push(0);
        data.push(16); // width=16
        data.push(1); // 1 component
        data.push(1); // id=1
        data.push(0x11); // H=1, V=1
        data.push(0); // quant table 0

        let mut reader = JpegReader::new(&data);
        let frame = parse_sof0(&mut reader).unwrap();
        assert_eq!(frame.width, 16);
        assert_eq!(frame.height, 16);
        assert_eq!(frame.components.len(), 1);
        assert_eq!(frame.h_max, 1);
        assert_eq!(frame.v_max, 1);
    }

    #[test]
    fn parse_sof0_three_components() {
        let mut data = Vec::new();
        let len: u16 = 8 + 9; // 3 components * 3 bytes each
        data.push((len >> 8) as u8);
        data.push(len as u8);
        data.push(8);
        data.push(0);
        data.push(32); // height=32
        data.push(0);
        data.push(32); // width=32
        data.push(3); // 3 components
        // Y: H=2, V=2
        data.push(1);
        data.push(0x22);
        data.push(0);
        // Cb: H=1, V=1
        data.push(2);
        data.push(0x11);
        data.push(1);
        // Cr: H=1, V=1
        data.push(3);
        data.push(0x11);
        data.push(1);

        let mut reader = JpegReader::new(&data);
        let frame = parse_sof0(&mut reader).unwrap();
        assert_eq!(frame.width, 32);
        assert_eq!(frame.height, 32);
        assert_eq!(frame.components.len(), 3);
        assert_eq!(frame.h_max, 2);
        assert_eq!(frame.v_max, 2);
        assert_eq!(frame.components[0].h_sample, 2);
        assert_eq!(frame.components[0].v_sample, 2);
        assert_eq!(frame.components[1].h_sample, 1);
    }

    // -- Clamp --

    #[test]
    fn clamp_values() {
        assert_eq!(clamp_to_u8(-10), 0);
        assert_eq!(clamp_to_u8(0), 0);
        assert_eq!(clamp_to_u8(128), 128);
        assert_eq!(clamp_to_u8(255), 255);
        assert_eq!(clamp_to_u8(300), 255);
    }
}
1716}