JwtAuthenticationTokenSecurityContextFactory.java

/*
 * Copyright 2019 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.bremersee.test.security.authentication;

import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;
import java.time.Instant;
import java.util.Arrays;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.TreeMap;
import lombok.extern.slf4j.Slf4j;
import org.bremersee.security.authentication.JsonPathJwtConverter;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.test.context.support.WithSecurityContextFactory;
import org.springframework.util.StringUtils;

/**
 * Security context factory for tests annotated with {@link WithJwtAuthenticationToken}.
 *
 * <p>The factory builds an unsigned ({@link PlainJWT}) token from the annotation values, wraps it
 * into a Spring {@link Jwt} and converts that into an {@link Authentication} by means of a
 * {@link JsonPathJwtConverter} that is configured from the annotation as well.
 *
 * @author Christian Bremer
 */
@Slf4j
public class JwtAuthenticationTokenSecurityContextFactory
    implements WithSecurityContextFactory<WithJwtAuthenticationToken> {

  /**
   * Creates a new, empty security context and populates it with the authentication that results
   * from converting the JWT described by the given annotation.
   *
   * @param withJwtAuthenticationToken the annotation that describes the token and the converter
   * @return the security context holding the converted authentication
   */
  @Override
  public SecurityContext createSecurityContext(
      WithJwtAuthenticationToken withJwtAuthenticationToken) {

    SecurityContext context = SecurityContextHolder.createEmptyContext();
    Jwt springJwt = createSpringJwt(createJwt(withJwtAuthenticationToken));
    Authentication authentication = createJwtConverter(withJwtAuthenticationToken)
        .convert(springJwt);
    context.setAuthentication(authentication);
    return context;
  }

  /**
   * Builds the converter that turns the JWT into an authentication token, copying the separator,
   * role prefix and json paths from the {@code jwtConverter()} attribute of the annotation.
   *
   * @param contextConfig the annotation
   * @return the configured converter
   */
  private Converter<Jwt, ? extends AbstractAuthenticationToken> createJwtConverter(
      WithJwtAuthenticationToken contextConfig) {
    JsonPathJwtConverter converter = new JsonPathJwtConverter();
    converter.setRolesValueSeparator(contextConfig.jwtConverter().rolesValueSeparator());
    converter.setRolePrefix(contextConfig.jwtConverter().rolePrefix());
    converter.setNameJsonPath(contextConfig.jwtConverter().nameJsonPath());
    converter.setRolesJsonPath(contextConfig.jwtConverter().rolesJsonPath());
    converter.setRolesValueList(contextConfig.jwtConverter().rolesValueList());
    return converter;
  }

  /**
   * Builds an unsigned JWT from the annotation values. The registered claims (aud, exp, iss, iat,
   * jti, nbf, sub) are set first, followed by the additional, possibly nested claims.
   *
   * @param tokenValues the annotation
   * @return the plain (unsigned) JWT
   */
  private JWT createJwt(WithJwtAuthenticationToken tokenValues) {
    // Use a single reference instant so that iat, nbf and exp are offsets of the same "now".
    final long now = System.currentTimeMillis();
    JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder()
        .audience(tokenValues.audience())
        .expirationTime(new Date(now + tokenValues.addMillisToExpirationTime()))
        .issuer(tokenValues.issuer())
        .issueTime(new Date(now + tokenValues.addMillisToIssueTime()))
        .jwtID(tokenValues.jwtId())
        .notBeforeTime(new Date(now + tokenValues.addMillisToNotBeforeTime()))
        .subject(tokenValues.subject());
    for (Map.Entry<String, Object> entry : createAdditionalClaims(tokenValues).entrySet()) {
      builder.claim(entry.getKey(), entry.getValue());
    }
    return new PlainJWT(builder.build());
  }

  /**
   * Expands the flat path map (for example {@code realm_access.roles -> [...]}) into a tree of
   * nested maps that can be added to the claims set.
   *
   * @param tokenValues the annotation
   * @return the additional claims, keyed by top level claim name
   */
  private Map<String, Object> createAdditionalClaims(WithJwtAuthenticationToken tokenValues) {
    Map<String, Object> map = new LinkedHashMap<>();
    for (Map.Entry<String, Object> entry : createPathMap(tokenValues).entrySet()) {
      addClaim(entry.getKey(), entry.getValue(), map);
    }
    return map;
  }

  /**
   * Stores the value under the given dot separated path, creating the intermediate maps on the
   * way. An existing intermediate entry that is not a map is replaced by a new map.
   *
   * @param path the dot separated path, e. g. {@code realm_access.roles}
   * @param value the value to store at the end of the path
   * @param map the root map that receives the claim
   */
  private void addClaim(String path, Object value, Map<String, Object> map) {
    Map<String, Object> current = map;
    String remaining = path;
    int index = remaining.indexOf('.');
    while (index >= 0) {
      String key = remaining.substring(0, index);
      Object child = current.get(key);
      Map<String, Object> childMap;
      if (child instanceof Map) {
        // Maps found here were created by this method and are therefore Map<String, Object>.
        @SuppressWarnings("unchecked")
        Map<String, Object> existing = (Map<String, Object>) child;
        childMap = existing;
      } else {
        childMap = new LinkedHashMap<>();
        current.put(key, childMap);
      }
      current = childMap;
      remaining = remaining.substring(index + 1);
      index = remaining.indexOf('.');
    }
    current.put(remaining, value);
  }

  /**
   * Collects the configurable claims of the annotation into a map from normalized path to value.
   * Entries whose path is empty after normalization are left out.
   *
   * @param tokenValues the annotation
   * @return the sorted map from normalized path to claim value
   */
  private Map<String, Object> createPathMap(WithJwtAuthenticationToken tokenValues) {
    Map<String, Object> map = new TreeMap<>();
    putPath(map, tokenValues.rolesPath(), Arrays.asList(tokenValues.roles()));
    putPath(map, tokenValues.scopePath(), Arrays.asList(tokenValues.scope()));
    putPath(map, tokenValues.namePath(), tokenValues.name());
    putPath(map, tokenValues.preferredUsernamePath(), tokenValues.preferredUsername());
    putPath(map, tokenValues.givenNamePath(), tokenValues.givenName());
    putPath(map, tokenValues.familyNamePath(), tokenValues.familyName());
    putPath(map, tokenValues.emailPath(), tokenValues.email());
    return map;
  }

  /**
   * Puts the value into the path map under the normalized form of the given path. The normalized
   * path (not the raw one) is used as key, so that leading, trailing or repeated dots and
   * surrounding white space never produce empty or padded claim names.
   *
   * @param pathMap the map to fill
   * @param rawPath the path as given in the annotation
   * @param value the claim value
   */
  private void putPath(Map<String, Object> pathMap, String rawPath, Object value) {
    String path = trimPath(rawPath);
    if (StringUtils.hasText(path)) {
      pathMap.put(path, value);
    }
  }

  /**
   * Normalizes a dot separated path: runs of dots are collapsed into a single dot, surrounding
   * white space is removed and leading as well as trailing dots are stripped.
   *
   * @param path the raw path, may be {@code null}
   * @return the normalized path, or {@code null} if the given path is {@code null}
   */
  private String trimPath(String path) {
    if (path == null) {
      return null;
    }
    // A regular expression is used because a single pass of replace("..", ".") leaves "a..b"
    // behind for the input "a...b".
    String tmp = path.replaceAll("\\.{2,}", ".").trim();
    while (tmp.startsWith(".")) {
      tmp = tmp.substring(1).trim();
    }
    while (tmp.endsWith(".")) {
      tmp = tmp.substring(0, tmp.length() - 1).trim();
    }
    return tmp;
  }

  /**
   * Converts the Nimbus JWT into a Spring Security {@link Jwt} using the serialized token value,
   * the issue and expiration time and the header and claims as json objects.
   *
   * @param jwt the Nimbus JWT
   * @return the Spring Security JWT
   * @throws RuntimeException if reading the claims or building the Spring JWT fails
   */
  private Jwt createSpringJwt(JWT jwt) {
    try {
      JWTClaimsSet claimsSet = jwt.getJWTClaimsSet();
      String tokenValue = jwt.serialize();
      Instant issuedAt = claimsSet.getIssueTime().toInstant();
      Instant expiresAt = claimsSet.getExpirationTime().toInstant();
      Map<String, Object> headers = jwt.getHeader().toJSONObject();
      Map<String, Object> claims = claimsSet.toJSONObject();
      return new Jwt(tokenValue, issuedAt, expiresAt, headers, claims);
    } catch (Exception e) {
      throw new RuntimeException("Creating Spring Jwt failed.", e);
    }
  }
}