diff options
Diffstat (limited to 'bcprov/src/main/java/org/bouncycastle/crypto/engines/Salsa20Engine.java')
-rw-r--r-- | bcprov/src/main/java/org/bouncycastle/crypto/engines/Salsa20Engine.java | 326 |
1 files changed, 242 insertions, 84 deletions
diff --git a/bcprov/src/main/java/org/bouncycastle/crypto/engines/Salsa20Engine.java b/bcprov/src/main/java/org/bouncycastle/crypto/engines/Salsa20Engine.java index 2d6140da..13aada9e 100644 --- a/bcprov/src/main/java/org/bouncycastle/crypto/engines/Salsa20Engine.java +++ b/bcprov/src/main/java/org/bouncycastle/crypto/engines/Salsa20Engine.java @@ -4,17 +4,17 @@ import org.bouncycastle.crypto.CipherParameters; import org.bouncycastle.crypto.DataLengthException; import org.bouncycastle.crypto.MaxBytesExceededException; import org.bouncycastle.crypto.OutputLengthException; -import org.bouncycastle.crypto.StreamCipher; +import org.bouncycastle.crypto.SkippingStreamCipher; import org.bouncycastle.crypto.params.KeyParameter; import org.bouncycastle.crypto.params.ParametersWithIV; -import org.bouncycastle.crypto.util.Pack; +import org.bouncycastle.util.Pack; import org.bouncycastle.util.Strings; /** * Implementation of Daniel J. Bernstein's Salsa20 stream cipher, Snuffle 2005 */ public class Salsa20Engine - implements StreamCipher + implements SkippingStreamCipher { public final static int DEFAULT_ROUNDS = 20; @@ -33,7 +33,7 @@ public class Salsa20Engine */ private int index = 0; protected int[] engineState = new int[STATE_SIZE]; // state - protected int[] x = new int[STATE_SIZE] ; // internal buffer + protected int[] x = new int[STATE_SIZE] ; // internal buffer private byte[] keyStream = new byte[STATE_SIZE * 4]; // expanded state, 64 bytes private boolean initialised = false; @@ -96,15 +96,27 @@ public class Salsa20Engine + " bytes of IV"); } - if (!(ivParams.getParameters() instanceof KeyParameter)) + CipherParameters keyParam = ivParams.getParameters(); + if (keyParam == null) { - throw new IllegalArgumentException(getAlgorithmName() + " Init parameters must include a key"); - } + if (!initialised) + { + throw new IllegalStateException(getAlgorithmName() + " KeyParameter can not be null for first initialisation"); + } - KeyParameter key = (KeyParameter) ivParams.getParameters(); + setKey(null, iv); + } + else if (keyParam instanceof KeyParameter) + { + setKey(((KeyParameter)keyParam).getKey(), iv); + } + else + { + throw new IllegalArgumentException(getAlgorithmName() + " Init parameters must contain a KeyParameter (or null for re-init)"); + } - setKey(key.getKey(), iv); reset(); + initialised = true; } @@ -130,18 +142,38 @@ public class Salsa20Engine throw new MaxBytesExceededException("2^70 byte limit per IV; Change IV"); } + byte out = (byte)(keyStream[index]^in); + index = (index + 1) & 63; + if (index == 0) { - generateKeyStream(keyStream); advanceCounter(); + generateKeyStream(keyStream); } - byte out = (byte)(keyStream[index]^in); - index = (index + 1) & 63; - return out; } + protected void advanceCounter(long diff) + { + int hi = (int)(diff >>> 32); + int lo = (int)diff; + + if (hi > 0) + { + engineState[9] += hi; + } + + int oldState = engineState[8]; + + engineState[8] += lo; + + if (oldState != 0 && engineState[8] < oldState) + { + engineState[9]++; + } + } + protected void advanceCounter() { if (++engineState[8] == 0) @@ -150,7 +182,55 @@ public class Salsa20Engine } } - public void processBytes( + protected void retreatCounter(long diff) + { + int hi = (int)(diff >>> 32); + int lo = (int)diff; + + if (hi != 0) + { + if ((engineState[9] & 0xffffffffL) >= (hi & 0xffffffffL)) + { + engineState[9] -= hi; + } + else + { + throw new IllegalStateException("attempt to reduce counter past zero."); + } + } + + if ((engineState[8] & 0xffffffffL) >= (lo & 0xffffffffL)) + { + engineState[8] -= lo; + } + else + { + if (engineState[9] != 0) + { + --engineState[9]; + engineState[8] -= lo; + } + else + { + throw new IllegalStateException("attempt to reduce counter past zero."); + } + } + } + + protected void retreatCounter() + { + if (engineState[8] == 0 && engineState[9] == 0) + { + throw new IllegalStateException("attempt to reduce counter past zero."); + } + + if (--engineState[8] == -1) + { + --engineState[9]; + } + } + + public int processBytes( byte[] in, int inOff, int len, @@ -179,15 +259,82 @@ public class Salsa20Engine for (int i = 0; i < len; i++) { + out[i + outOff] = (byte)(keyStream[index] ^ in[i + inOff]); + index = (index + 1) & 63; + if (index == 0) { + advanceCounter(); generateKeyStream(keyStream); + } + } + + return len; + } + + public long skip(long numberOfBytes) + { + if (numberOfBytes >= 0) + { + long remaining = numberOfBytes; + + if (remaining >= 64) + { + long count = remaining / 64; + + advanceCounter(count); + + remaining -= count * 64; + } + + int oldIndex = index; + + index = (index + (int)remaining) & 63; + + if (index < oldIndex) + { advanceCounter(); } + } + else + { + long remaining = -numberOfBytes; - out[i+outOff] = (byte)(keyStream[index]^in[i+inOff]); - index = (index + 1) & 63; + if (remaining >= 64) + { + long count = remaining / 64; + + retreatCounter(count); + + remaining -= count * 64; + } + + for (long i = 0; i < remaining; i++) + { + if (index == 0) + { + retreatCounter(); + } + + index = (index - 1) & 63; + } } + + generateKeyStream(keyStream); + + return numberOfBytes; + } + + public long seekTo(long position) + { + reset(); + + return skip(position); + } + + public long getPosition() + { + return getCounter() * 64 + index; } public void reset() @@ -195,6 +342,13 @@ public class Salsa20Engine index = 0; resetLimitCounter(); resetCounter(); + + generateKeyStream(keyStream); + } + + protected long getCounter() + { + return ((long)engineState[9] << 32) | (engineState[8] & 0xffffffffL); } protected void resetCounter() @@ -204,43 +358,46 @@ public class Salsa20Engine protected void setKey(byte[] keyBytes, byte[] ivBytes) { - if ((keyBytes.length != 16) && (keyBytes.length != 32)) { - throw new IllegalArgumentException(getAlgorithmName() + " requires 128 bit or 256 bit key"); - } - - int offset = 0; - byte[] constants; - - // Key - engineState[1] = Pack.littleEndianToInt(keyBytes, 0); - engineState[2] = Pack.littleEndianToInt(keyBytes, 4); - engineState[3] = Pack.littleEndianToInt(keyBytes, 8); - engineState[4] = Pack.littleEndianToInt(keyBytes, 12); - - if (keyBytes.length == 32) - { - constants = sigma; - offset = 16; - } - else + if (keyBytes != null) { - constants = tau; - } + if ((keyBytes.length != 16) && (keyBytes.length != 32)) + { + throw new IllegalArgumentException(getAlgorithmName() + " requires 128 bit or 256 bit key"); + } - engineState[11] = Pack.littleEndianToInt(keyBytes, offset); - engineState[12] = Pack.littleEndianToInt(keyBytes, offset+4); - engineState[13] = Pack.littleEndianToInt(keyBytes, offset+8); - engineState[14] = Pack.littleEndianToInt(keyBytes, offset+12); + // Key + engineState[1] = Pack.littleEndianToInt(keyBytes, 0); + engineState[2] = Pack.littleEndianToInt(keyBytes, 4); + engineState[3] = Pack.littleEndianToInt(keyBytes, 8); + engineState[4] = Pack.littleEndianToInt(keyBytes, 12); - engineState[0 ] = Pack.littleEndianToInt(constants, 0); - engineState[5 ] = Pack.littleEndianToInt(constants, 4); - engineState[10] = Pack.littleEndianToInt(constants, 8); - engineState[15] = Pack.littleEndianToInt(constants, 12); + byte[] constants; + int offset; + if (keyBytes.length == 32) + { + constants = sigma; + offset = 16; + } + else + { + constants = tau; + offset = 0; + } + + engineState[11] = Pack.littleEndianToInt(keyBytes, offset); + engineState[12] = Pack.littleEndianToInt(keyBytes, offset + 4); + engineState[13] = Pack.littleEndianToInt(keyBytes, offset + 8); + engineState[14] = Pack.littleEndianToInt(keyBytes, offset + 12); + + engineState[0 ] = Pack.littleEndianToInt(constants, 0); + engineState[5 ] = Pack.littleEndianToInt(constants, 4); + engineState[10] = Pack.littleEndianToInt(constants, 8); + engineState[15] = Pack.littleEndianToInt(constants, 12); + } // IV engineState[6] = Pack.littleEndianToInt(ivBytes, 0); engineState[7] = Pack.littleEndianToInt(ivBytes, 4); - resetCounter(); } protected void generateKeyStream(byte[] output) @@ -253,18 +410,19 @@ public class Salsa20Engine * Salsa20 function * * @param input input data - * - * @return keystream */ public static void salsaCore(int rounds, int[] input, int[] x) { - if (input.length != 16) { + if (input.length != 16) + { throw new IllegalArgumentException(); } - if (x.length != 16) { + if (x.length != 16) + { throw new IllegalArgumentException(); } - if (rounds % 2 != 0) { + if (rounds % 2 != 0) + { throw new IllegalArgumentException("Number of rounds must be even"); } @@ -287,39 +445,39 @@ public class Salsa20Engine for (int i = rounds; i > 0; i -= 2) { - x04 ^= rotl((x00+x12), 7); - x08 ^= rotl((x04+x00), 9); - x12 ^= rotl((x08+x04),13); - x00 ^= rotl((x12+x08),18); - x09 ^= rotl((x05+x01), 7); - x13 ^= rotl((x09+x05), 9); - x01 ^= rotl((x13+x09),13); - x05 ^= rotl((x01+x13),18); - x14 ^= rotl((x10+x06), 7); - x02 ^= rotl((x14+x10), 9); - x06 ^= rotl((x02+x14),13); - x10 ^= rotl((x06+x02),18); - x03 ^= rotl((x15+x11), 7); - x07 ^= rotl((x03+x15), 9); - x11 ^= rotl((x07+x03),13); - x15 ^= rotl((x11+x07),18); - - x01 ^= rotl((x00+x03), 7); - x02 ^= rotl((x01+x00), 9); - x03 ^= rotl((x02+x01),13); - x00 ^= rotl((x03+x02),18); - x06 ^= rotl((x05+x04), 7); - x07 ^= rotl((x06+x05), 9); - x04 ^= rotl((x07+x06),13); - x05 ^= rotl((x04+x07),18); - x11 ^= rotl((x10+x09), 7); - x08 ^= rotl((x11+x10), 9); - x09 ^= rotl((x08+x11),13); - x10 ^= rotl((x09+x08),18); - x12 ^= rotl((x15+x14), 7); - x13 ^= rotl((x12+x15), 9); - x14 ^= rotl((x13+x12),13); - x15 ^= rotl((x14+x13),18); + x04 ^= rotl(x00 + x12, 7); + x08 ^= rotl(x04 + x00, 9); + x12 ^= rotl(x08 + x04, 13); + x00 ^= rotl(x12 + x08, 18); + x09 ^= rotl(x05 + x01, 7); + x13 ^= rotl(x09 + x05, 9); + x01 ^= rotl(x13 + x09, 13); + x05 ^= rotl(x01 + x13, 18); + x14 ^= rotl(x10 + x06, 7); + x02 ^= rotl(x14 + x10, 9); + x06 ^= rotl(x02 + x14, 13); + x10 ^= rotl(x06 + x02, 18); + x03 ^= rotl(x15 + x11, 7); + x07 ^= rotl(x03 + x15, 9); + x11 ^= rotl(x07 + x03, 13); + x15 ^= rotl(x11 + x07, 18); + + x01 ^= rotl(x00 + x03, 7); + x02 ^= rotl(x01 + x00, 9); + x03 ^= rotl(x02 + x01, 13); + x00 ^= rotl(x03 + x02, 18); + x06 ^= rotl(x05 + x04, 7); + x07 ^= rotl(x06 + x05, 9); + x04 ^= rotl(x07 + x06, 13); + x05 ^= rotl(x04 + x07, 18); + x11 ^= rotl(x10 + x09, 7); + x08 ^= rotl(x11 + x10, 9); + x09 ^= rotl(x08 + x11, 13); + x10 ^= rotl(x09 + x08, 18); + x12 ^= rotl(x15 + x14, 7); + x13 ^= rotl(x12 + x15, 9); + x14 ^= rotl(x13 + x12, 13); + x15 ^= rotl(x14 + x13, 18); } x[ 0] = x00 + input[ 0]; |