// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //////////////////////////////////////////////////////////////////////////////// package com.google.crypto.tink.jwt; import static com.google.crypto.tink.internal.Util.UTF_8; import com.google.crypto.tink.proto.OutputPrefixType; import com.google.crypto.tink.subtle.Base64; import com.google.gson.JsonObject; import java.nio.ByteBuffer; import java.nio.charset.CharacterCodingException; import java.nio.charset.CharsetDecoder; import java.security.InvalidAlgorithmParameterException; import java.util.Optional; final class JwtFormat { static class Parts { String unsignedCompact; byte[] signatureOrMac; String header; String payload; Parts( String unsignedCompact, byte[] signatureOrMac, String header, String payload) { this.unsignedCompact = unsignedCompact; this.signatureOrMac = signatureOrMac; this.header = header; this.payload = payload; } } private JwtFormat() {} static boolean isValidUrlsafeBase64Char(char c) { return (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')) || ((c >= '0') && (c <= '9')) || ((c == '-') || (c == '_'))); } // We need this validation, since String(data, UTF_8) ignores invalid characters. static void validateUtf8(byte[] data) throws JwtInvalidException { CharsetDecoder decoder = UTF_8.newDecoder(); try { decoder.decode(ByteBuffer.wrap(data)); } catch (CharacterCodingException ex) { throw new JwtInvalidException(ex.getMessage()); } } static byte[] strictUrlSafeDecode(String encodedData) throws JwtInvalidException { for (int i = 0; i < encodedData.length(); i++) { char c = encodedData.charAt(i); if (!isValidUrlsafeBase64Char(c)) { throw new JwtInvalidException("invalid encoding"); } } try { return Base64.urlSafeDecode(encodedData); } catch (IllegalArgumentException ex) { throw new JwtInvalidException("invalid encoding: " + ex); } } private static void validateAlgorithm(String algo) throws InvalidAlgorithmParameterException { switch (algo) { case "HS256": case "HS384": case "HS512": case "ES256": case "ES384": case "ES512": case "RS256": case "RS384": case "RS512": case "PS256": case "PS384": case "PS512": return; default: throw new InvalidAlgorithmParameterException("invalid algorithm: " + algo); } } static String createHeader(String algorithm, Optional typeHeader, Optional kid) throws InvalidAlgorithmParameterException { validateAlgorithm(algorithm); JsonObject header = new JsonObject(); if (kid.isPresent()) { header.addProperty(JwtNames.HEADER_KEY_ID, kid.get()); } header.addProperty(JwtNames.HEADER_ALGORITHM, algorithm); if (typeHeader.isPresent()) { header.addProperty(JwtNames.HEADER_TYPE, typeHeader.get()); } return Base64.urlSafeEncode(header.toString().getBytes(UTF_8)); } private static void validateKidInHeader(String expectedKid, JsonObject parsedHeader) throws JwtInvalidException { String kid = getStringHeader(parsedHeader, JwtNames.HEADER_KEY_ID); if (!kid.equals(expectedKid)) { throw new JwtInvalidException("invalid kid in header"); } } /** * Validates the parsed header. * * tinkKid should only be set for keys with output prefix type TINK. customKid should only * be set for keys with output prefix type RAW. They should not be set at the same time. */ static void validateHeader( String expectedAlgorithm, Optional tinkKid, Optional customKid, JsonObject parsedHeader) throws InvalidAlgorithmParameterException, JwtInvalidException { validateAlgorithm(expectedAlgorithm); String algorithm = getStringHeader(parsedHeader, JwtNames.HEADER_ALGORITHM); if (!algorithm.equals(expectedAlgorithm)) { throw new InvalidAlgorithmParameterException( String.format( "invalid algorithm; expected %s, got %s", expectedAlgorithm, algorithm)); } if (parsedHeader.has(JwtNames.HEADER_CRITICAL)) { throw new JwtInvalidException("all tokens with crit headers are rejected"); } if (tinkKid.isPresent() && customKid.isPresent()) { throw new JwtInvalidException("custom_kid can only be set for RAW keys."); } boolean headerHasKid = parsedHeader.has(JwtNames.HEADER_KEY_ID); if (tinkKid.isPresent()) { if (!headerHasKid) { // for output prefix type TINK, the kid header is required. throw new JwtInvalidException("missing kid in header"); } validateKidInHeader(tinkKid.get(), parsedHeader); } if (customKid.isPresent() && headerHasKid) { // for output prefix type RAW, the kid header is not required, even if custom kid is set. validateKidInHeader(customKid.get(), parsedHeader); } // Ignore all other headers } static Optional getTypeHeader(JsonObject header) throws JwtInvalidException { if (header.has(JwtNames.HEADER_TYPE)) { return Optional.of(getStringHeader(header, JwtNames.HEADER_TYPE)); } return Optional.empty(); } private static String getStringHeader(JsonObject header, String name) throws JwtInvalidException { if (!header.has(name)) { throw new JwtInvalidException("header " + name + " does not exist"); } if (!header.get(name).isJsonPrimitive() || !header.get(name).getAsJsonPrimitive().isString()) { throw new JwtInvalidException("header " + name + " is not a string"); } return header.get(name).getAsString(); } static String decodeHeader(String headerStr) throws JwtInvalidException { byte[] data = strictUrlSafeDecode(headerStr); validateUtf8(data); return new String(data, UTF_8); } static String encodePayload(String jsonPayload) { return Base64.urlSafeEncode(jsonPayload.getBytes(UTF_8)); } static String decodePayload(String payloadStr) throws JwtInvalidException { byte[] data = strictUrlSafeDecode(payloadStr); validateUtf8(data); return new String(data, UTF_8); } static String encodeSignature(byte[] signature) { return Base64.urlSafeEncode(signature); } static byte[] decodeSignature(String signatureStr) throws JwtInvalidException { return strictUrlSafeDecode(signatureStr); } static Optional getKid(int keyId, OutputPrefixType prefix) throws JwtInvalidException { if (prefix == OutputPrefixType.RAW) { return Optional.empty(); } if (prefix == OutputPrefixType.TINK) { byte[] bigEndianKeyId = ByteBuffer.allocate(4).putInt(keyId).array(); return Optional.of(Base64.urlSafeEncode(bigEndianKeyId)); } throw new JwtInvalidException("unsupported output prefix type"); } static Optional getKeyId(String kid) { byte[] encodedKeyId = Base64.urlSafeDecode(kid); if (encodedKeyId.length != 4) { return Optional.empty(); } return Optional.of(ByteBuffer.wrap(encodedKeyId).getInt()); } static Parts splitSignedCompact(String signedCompact) throws JwtInvalidException { validateASCII(signedCompact); int sigPos = signedCompact.lastIndexOf('.'); if (sigPos < 0) { throw new JwtInvalidException( "only tokens in JWS compact serialization format are supported"); } String unsignedCompact = signedCompact.substring(0, sigPos); String encodedMac = signedCompact.substring(sigPos + 1); byte[] mac = decodeSignature(encodedMac); int payloadPos = unsignedCompact.indexOf('.'); if (payloadPos < 0) { throw new JwtInvalidException( "only tokens in JWS compact serialization format are supported"); } String encodedHeader = unsignedCompact.substring(0, payloadPos); String encodedPayload = unsignedCompact.substring(payloadPos + 1); if (encodedPayload.indexOf('.') > 0) { throw new JwtInvalidException( "only tokens in JWS compact serialization format are supported"); } String header = decodeHeader(encodedHeader); String payload = decodePayload(encodedPayload); return new Parts(unsignedCompact, mac, header, payload); } static String createUnsignedCompact(String algorithm, Optional kid, RawJwt rawJwt) throws InvalidAlgorithmParameterException, JwtInvalidException { String jsonPayload = rawJwt.getJsonPayload(); Optional typeHeader = rawJwt.hasTypeHeader() ? Optional.of(rawJwt.getTypeHeader()) : Optional.empty(); return createHeader(algorithm, typeHeader, kid) + "." + encodePayload(jsonPayload); } static String createSignedCompact(String unsignedCompact, byte[] signature) { return unsignedCompact + "." + encodeSignature(signature); } static void validateASCII(String data) throws JwtInvalidException { for (int i = 0; i < data.length(); i++) { char c = data.charAt(i); if ((c & 0x80) > 0) { throw new JwtInvalidException("Non ascii character"); } } } }