aboutsummaryrefslogtreecommitdiff
path: root/okio/src/jvmMain/kotlin/okio/Throttler.kt
blob: 859fb71606de97cc49b01310af31a196574a78a3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
/*
 * Copyright (C) 2018 Square, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package okio

import java.io.IOException
import java.io.InterruptedIOException
import java.util.concurrent.locks.Condition
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock

/**
 * Enables limiting of Source and Sink throughput. Attach to this throttler via [source] and [sink]
 * and set the desired throughput via [bytesPerSecond]. Multiple Sources and Sinks can be
 * attached to a single Throttler and they will be throttled as a group, where their combined
 * throughput will not exceed the desired throughput. The same Source or Sink can be attached to
 * multiple Throttlers and its throughput will not exceed the desired throughput of any of the
 * Throttlers.
 *
 * This class has these tuning parameters:
 *
 *  * `bytesPerSecond`: Maximum sustained throughput. Use 0 for no limit.
 *  * `waitByteCount`: When the requested byte count is greater than this many bytes and isn't
 *    immediately available, only wait until we can allocate at least this many bytes. Use this to
 *    set the ideal byte count during sustained throughput.
 *  * `maxByteCount`: Maximum number of bytes to allocate on any call. This is also the number of
 *    bytes that will be returned before any waiting.
 *
 * Updating the tuning parameters and taking bytes both happen while holding [lock]. Threads
 * blocked waiting for bytes are woken via [condition] whenever the parameters change.
 */
class Throttler internal constructor(
  /**
   * The nanoTime that we've consumed all bytes through. This is never greater than the current
   * nanoTime plus nanosForMaxByteCount.
   */
  private var allocatedUntil: Long,
) {
  /** Sustained rate in bytes per second; 0 disables throttling. Written under [lock]. */
  private var bytesPerSecond: Long = 0L

  /** Smallest grant worth blocking for when a request can't be satisfied immediately. */
  private var waitByteCount: Long = 8 * 1024 // 8 KiB.

  /** Largest grant on any single call, and the burst size available before any waiting. */
  private var maxByteCount: Long = 256 * 1024 // 256 KiB.

  /** Lock held while reading or updating the throttling parameters and allocation state. */
  val lock: ReentrantLock = ReentrantLock()

  /** Signalled when the parameters change so threads waiting in [take] recompute their wait. */
  val condition: Condition = lock.newCondition()

  constructor() : this(allocatedUntil = System.nanoTime())

  /**
   * Sets the rate at which bytes will be allocated. Use 0 for no limit.
   *
   * @param bytesPerSecond maximum sustained throughput; must be non-negative.
   * @param waitByteCount minimum allocation worth waiting for; must be positive.
   * @param maxByteCount maximum allocation per call; must be at least [waitByteCount].
   * @throws IllegalArgumentException if any argument is out of range.
   */
  @JvmOverloads
  fun bytesPerSecond(
    bytesPerSecond: Long,
    waitByteCount: Long = this.waitByteCount,
    maxByteCount: Long = this.maxByteCount,
  ) {
    lock.withLock {
      require(bytesPerSecond >= 0) { "bytesPerSecond < 0: $bytesPerSecond" }
      require(waitByteCount > 0) { "waitByteCount <= 0: $waitByteCount" }
      require(maxByteCount >= waitByteCount) {
        "maxByteCount < waitByteCount: $maxByteCount < $waitByteCount"
      }

      this.bytesPerSecond = bytesPerSecond
      this.waitByteCount = waitByteCount
      this.maxByteCount = maxByteCount
      // Waiters computed their sleep under the old parameters; let them re-evaluate.
      condition.signalAll()
    }
  }

  /**
   * Take up to `byteCount` bytes, waiting if necessary. Returns the number of bytes that were
   * taken, which is between 1 and `byteCount` inclusive.
   *
   * @throws IllegalArgumentException if `byteCount` is not positive.
   * @throws InterruptedException if the calling thread is interrupted while waiting.
   */
  internal fun take(byteCount: Long): Long {
    require(byteCount > 0) { "byteCount <= 0: $byteCount" }

    lock.withLock {
      while (true) {
        val now = System.nanoTime()
        val byteCountOrWaitNanos = byteCountOrWaitNanos(now, byteCount)
        if (byteCountOrWaitNanos >= 0) return byteCountOrWaitNanos
        // awaitNanos may return early (signalAll or a spurious wakeup); the loop re-checks.
        condition.awaitNanos(-byteCountOrWaitNanos)
      }
    }
  }

  /**
   * Returns the byte count to take immediately or -1 times the number of nanos to wait until the
   * next attempt. If the returned value is negative it should be interpreted as a duration in
   * nanos; otherwise it should be interpreted as a byte count.
   *
   * When bytes are granted under a non-zero rate, this advances [allocatedUntil] to account for
   * them. [take] calls this while holding [lock].
   */
  internal fun byteCountOrWaitNanos(now: Long, byteCount: Long): Long {
    if (bytesPerSecond == 0L) return byteCount // No limits.

    val idleInNanos = maxOf(allocatedUntil - now, 0L)
    val immediateBytes = maxByteCount - idleInNanos.nanosToBytes()

    // Fulfill the entire request without waiting.
    if (immediateBytes >= byteCount) {
      allocatedUntil = now + idleInNanos + byteCount.bytesToNanos()
      return byteCount
    }

    // Fulfill a big-enough block without waiting.
    if (immediateBytes >= waitByteCount) {
      allocatedUntil = now + maxByteCount.bytesToNanos()
      return immediateBytes
    }

    // Looks like we'll need to wait until we can take the minimum required bytes.
    val minByteCount = minOf(waitByteCount, byteCount)
    val minWaitNanos = idleInNanos + (minByteCount - maxByteCount).bytesToNanos()

    // But if the wait duration truncates to zero nanos after division, don't wait.
    if (minWaitNanos == 0L) {
      allocatedUntil = now + maxByteCount.bytesToNanos()
      return minByteCount
    }

    return -minWaitNanos
  }

  /** Converts a duration in nanos to the bytes transferable at [bytesPerSecond] (truncating). */
  private fun Long.nanosToBytes() = this * bytesPerSecond / 1_000_000_000L

  /** Converts a byte count to the nanos needed at [bytesPerSecond] (truncating toward zero). */
  private fun Long.bytesToNanos() = this * 1_000_000_000L / bytesPerSecond

  /** Create a Source which honors this Throttler.  */
  fun source(source: Source): Source {
    return object : ForwardingSource(source) {
      @Throws(IOException::class)
      override fun read(sink: Buffer, byteCount: Long): Long {
        // take() requires a positive count, so don't route zero-byte reads through it; let the
        // delegate answer them directly without consuming any throughput budget.
        if (byteCount == 0L) return super.read(sink, 0L)

        try {
          val toRead = take(byteCount)
          return super.read(sink, toRead)
        } catch (e: InterruptedException) {
          Thread.currentThread().interrupt()
          throw InterruptedIOException("interrupted").apply { initCause(e) }
        }
      }
    }
  }

  /** Create a Sink which honors this Throttler.  */
  fun sink(sink: Sink): Sink {
    return object : ForwardingSink(sink) {
      @Throws(IOException::class)
      override fun write(source: Buffer, byteCount: Long) {
        try {
          // take() may grant fewer bytes than requested, so write in throttled chunks.
          var remaining = byteCount
          while (remaining > 0L) {
            val toWrite = take(remaining)
            super.write(source, toWrite)
            remaining -= toWrite
          }
        } catch (e: InterruptedException) {
          // Restore the interrupt flag before converting to an IOException-compatible type.
          Thread.currentThread().interrupt()
          throw InterruptedIOException("interrupted").apply { initCause(e) }
        }
      }
    }
  }
}