1// Copyright 2024 The Gitea Authors. All rights reserved.
2// SPDX-License-Identifier: MIT
3
4// Package zstd provides a high-level API for reading and writing zstd-compressed data.
5// It supports both regular and seekable zstd streams.
6// It's not a new wheel, but a wrapper around the zstd and zstd-seekable-format-go packages.
7package zstd
8
9import (
10 "errors"
11 "io"
12
13 seekable "github.com/SaveTheRbtz/zstd-seekable-format-go/pkg"
14 "github.com/klauspost/compress/zstd"
15)
16
// Writer is a zstd-compressing io.WriteCloser. It is declared directly on top
// of zstd.Encoder, so NewWriter can convert the encoder pointer to a *Writer
// without wrapping it in another struct.
type Writer zstd.Encoder

// Compile-time check that *Writer satisfies io.WriteCloser.
var _ io.WriteCloser = (*Writer)(nil)
20
21// NewWriter returns a new zstd writer.
22func NewWriter(w io.Writer, opts ...WriterOption) (*Writer, error) {
23 zstdW, err := zstd.NewWriter(w, opts...)
24 if err != nil {
25 return nil, err
26 }
27 return (*Writer)(zstdW), nil
28}
29
30func (w *Writer) Write(p []byte) (int, error) {
31 return (*zstd.Encoder)(w).Write(p)
32}
33
34func (w *Writer) Close() error {
35 return (*zstd.Encoder)(w).Close()
36}
37
// Reader is a zstd-decompressing io.ReadCloser. It is declared directly on top
// of zstd.Decoder, so NewReader can convert the decoder pointer to a *Reader
// without wrapping it in another struct.
type Reader zstd.Decoder

// Compile-time check that *Reader satisfies io.ReadCloser.
var _ io.ReadCloser = (*Reader)(nil)
41
42// NewReader returns a new zstd reader.
43func NewReader(r io.Reader, opts ...ReaderOption) (*Reader, error) {
44 zstdR, err := zstd.NewReader(r, opts...)
45 if err != nil {
46 return nil, err
47 }
48 return (*Reader)(zstdR), nil
49}
50
51func (r *Reader) Read(p []byte) (int, error) {
52 return (*zstd.Decoder)(r).Read(p)
53}
54
55func (r *Reader) Close() error {
56 (*zstd.Decoder)(r).Close() // no error returned
57 return nil
58}
59
// SeekableWriter compresses data into the seekable zstd format. Incoming bytes
// are staged in buf and handed to the seekable writer one full buffer at a
// time; Close passes on whatever partial buffer remains.
type SeekableWriter struct {
	// buf is the staging buffer; NewSeekableWriter allocates it with the
	// requested block size.
	buf []byte
	// n is the number of bytes currently staged at the start of buf.
	n int
	// w receives each staged block.
	w seekable.Writer
}

// Compile-time check that *SeekableWriter satisfies io.WriteCloser.
var _ io.WriteCloser = (*SeekableWriter)(nil)
67
68// NewSeekableWriter returns a zstd writer to compress data to seekable format.
69// blockSize is an important parameter, it should be decided according to the actual business requirements.
70// If it's too small, the compression ratio could be very bad, even no compression at all.
71// If it's too large, it could cost more traffic when reading the data partially from underlying storage.
72func NewSeekableWriter(w io.Writer, blockSize int, opts ...WriterOption) (*SeekableWriter, error) {
73 zstdW, err := zstd.NewWriter(nil, opts...)
74 if err != nil {
75 return nil, err
76 }
77
78 seekableW, err := seekable.NewWriter(w, zstdW)
79 if err != nil {
80 return nil, err
81 }
82
83 return &SeekableWriter{
84 buf: make([]byte, blockSize),
85 w: seekableW,
86 }, nil
87}
88
89func (w *SeekableWriter) Write(p []byte) (int, error) {
90 written := 0
91 for len(p) > 0 {
92 n := copy(w.buf[w.n:], p)
93 w.n += n
94 written += n
95 p = p[n:]
96
97 if w.n == len(w.buf) {
98 if _, err := w.w.Write(w.buf); err != nil {
99 return written, err
100 }
101 w.n = 0
102 }
103 }
104 return written, nil
105}
106
107func (w *SeekableWriter) Close() error {
108 if w.n > 0 {
109 if _, err := w.w.Write(w.buf[:w.n]); err != nil {
110 return err
111 }
112 }
113 return w.w.Close()
114}
115
// SeekableReader reads and seeks within data stored in the seekable zstd
// format by delegating to a seekable reader.
type SeekableReader struct {
	// r serves Read and Seek and is closed by Close.
	r seekable.Reader
	// c closes the source passed to NewSeekableReader; it stays nil when that
	// source does not implement io.Closer.
	c func() error
}

// Compile-time check that *SeekableReader satisfies io.ReadSeekCloser.
var _ io.ReadSeekCloser = (*SeekableReader)(nil)
122
123// NewSeekableReader returns a zstd reader to decompress data from seekable format.
124func NewSeekableReader(r io.ReadSeeker, opts ...ReaderOption) (*SeekableReader, error) {
125 zstdR, err := zstd.NewReader(nil, opts...)
126 if err != nil {
127 return nil, err
128 }
129
130 seekableR, err := seekable.NewReader(r, zstdR)
131 if err != nil {
132 return nil, err
133 }
134
135 ret := &SeekableReader{
136 r: seekableR,
137 }
138 if closer, ok := r.(io.Closer); ok {
139 ret.c = closer.Close
140 }
141
142 return ret, nil
143}
144
145func (r *SeekableReader) Read(p []byte) (int, error) {
146 return r.r.Read(p)
147}
148
149func (r *SeekableReader) Seek(offset int64, whence int) (int64, error) {
150 return r.r.Seek(offset, whence)
151}
152
153func (r *SeekableReader) Close() error {
154 return errors.Join(
155 func() error {
156 if r.c != nil {
157 return r.c()
158 }
159 return nil
160 }(),
161 r.r.Close(),
162 )
163}