aboutsummaryrefslogtreecommitdiff
path: root/okio/src/jvmMain/kotlin/okio/CipherSink.kt
blob: aa482842ba75f5562617ad1ab869891d95a5cac4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
/*
 * Copyright (C) 2020 Square, Inc. and others.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package okio

import java.io.IOException
import javax.crypto.Cipher

/**
 * A [Sink] that runs every byte written to it through [cipher] and forwards the cipher's output
 * to [sink].
 *
 * Closing this sink finalizes the cipher (writing any output it still produces) and then closes
 * the underlying [sink].
 *
 * @param sink destination for the processed bytes.
 * @param cipher the cipher to apply; it must report a positive block size.
 * @throws IllegalArgumentException if [cipher] reports a block size of zero or less.
 */
class CipherSink(
  private val sink: BufferedSink,
  val cipher: Cipher
) : Sink {
  private val blockSize = cipher.blockSize
  private var closed = false

  init {
    // Require block cipher
    require(blockSize > 0) { "Block cipher required $cipher" }
  }

  /**
   * Removes [byteCount] bytes from [source], feeds them to the cipher and writes whatever output
   * the cipher produces to the underlying sink.
   *
   * @throws IllegalStateException if this sink has been closed.
   */
  @Throws(IOException::class)
  override fun write(source: Buffer, byteCount: Long) {
    checkOffsetAndCount(source.size, 0, byteCount)
    check(!closed) { "closed" }

    // Each update() call consumes at most the readable part of source's head segment.
    var remaining = byteCount
    while (remaining > 0) {
      val size = update(source, remaining)
      remaining -= size
    }
  }

  /**
   * Feeds up to [remaining] bytes from the head segment of [source] to the cipher.
   *
   * @return the number of bytes consumed from [source]; always at least 1.
   */
  private fun update(source: Buffer, remaining: Long): Int {
    // Non-null: write() only calls this while bytes it has validated are still in source.
    val head = source.head!!
    val available = minOf(remaining, head.limit - head.pos).toInt()
    val buffer = sink.buffer

    // Shorten input until output is guaranteed to fit within a segment
    var size = available
    var outputSize = cipher.getOutputSize(size)
    while (outputSize > Segment.SIZE) {
      if (size <= blockSize) {
        // Shortening the input cannot bring the output below one segment (for example when the
        // cipher is holding on to earlier input internally). Bypass the segment: let the cipher
        // allocate its own output array and write that through the sink instead.
        val output: ByteArray? = cipher.update(head.data, head.pos, available)
        // Cipher.update may return null when the input did not produce any output yet.
        if (output != null) sink.write(output)
        markRead(source, head, available)
        return available
      }
      size -= blockSize
      outputSize = cipher.getOutputSize(size)
    }
    val s = buffer.writableSegment(outputSize)

    // Write the cipher output straight into the free space of the sink's tail segment.
    val ciphered = cipher.update(head.data, head.pos, size, s.data, s.limit)

    s.limit += ciphered
    buffer.size += ciphered

    // We allocated a tail segment, but didn't end up needing it. Recycle!
    if (s.pos == s.limit) {
      buffer.head = s.pop()
      SegmentPool.recycle(s)
    }

    sink.emitCompleteSegments()

    markRead(source, head, size)
    return size
  }

  /** Marks [byteCount] bytes of [head] as read and recycles the segment once it is exhausted. */
  private fun markRead(source: Buffer, head: Segment, byteCount: Int) {
    source.size -= byteCount
    head.pos += byteCount
    if (head.pos == head.limit) {
      source.head = head.pop()
      SegmentPool.recycle(head)
    }
  }

  override fun flush() = sink.flush()

  override fun timeout() = sink.timeout()

  /**
   * Finalizes the cipher, writes its remaining output and closes the underlying sink.
   *
   * The underlying sink is closed even when finalizing the cipher fails. If both steps fail, the
   * cipher's failure is the one thrown. Calling this more than once has no further effect.
   */
  @Throws(IOException::class)
  override fun close() {
    if (closed) return
    closed = true

    var thrown = doFinal()

    try {
      sink.close()
    } catch (e: Throwable) {
      if (thrown == null) thrown = e
    }

    if (thrown != null) throw thrown
  }

  /**
   * Finalizes the cipher and appends its remaining output to the sink.
   *
   * @return the failure encountered, or null on success. Failures are returned rather than thrown
   *     so that [close] can still close the underlying sink.
   */
  private fun doFinal(): Throwable? {
    val outputSize = try {
      cipher.getOutputSize(0)
    } catch (t: Throwable) {
      // Report instead of throwing, otherwise close() would skip closing the sink.
      return t
    }
    if (outputSize == 0) return null

    if (outputSize > Segment.SIZE) {
      // The final output does not fit in a single segment. Bypass the segment: let the cipher
      // allocate the output array and write it through the sink.
      return try {
        sink.write(cipher.doFinal())
        null
      } catch (t: Throwable) {
        t
      }
    }

    var thrown: Throwable? = null
    val buffer = sink.buffer

    // The final output fits in one segment: write it directly into the sink's tail segment.
    val s = buffer.writableSegment(outputSize)

    try {
      val ciphered = cipher.doFinal(s.data, s.limit)

      s.limit += ciphered
      buffer.size += ciphered
    } catch (e: Throwable) {
      thrown = e
    }

    // Nothing was written into a freshly allocated tail segment: give it back to the pool.
    if (s.pos == s.limit) {
      buffer.head = s.pop()
      SegmentPool.recycle(s)
    }

    return thrown
  }
}