Skip to content

Commit

Permalink
Add Authorization Proxy Support
Browse files Browse the repository at this point in the history
Closes gh-14596
  • Loading branch information
jzheaux committed Mar 1, 2024
1 parent bade66e commit 6d290c9
Show file tree
Hide file tree
Showing 2 changed files with 283 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.authorization.object;

import java.lang.reflect.Array;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import org.springframework.aop.Advisor;
import org.springframework.aop.framework.ProxyFactory;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.util.ClassUtils;

public final class AuthorizationProxyFactory {

private final Collection<Advisor> advisors;

public AuthorizationProxyFactory(Advisor... advisors) {
this.advisors = List.of(advisors);
}

public AuthorizationProxyFactory(Collection<Advisor> advisors) {
this.advisors = List.copyOf(advisors);
}

public AuthorizationProxyFactory withAdvisors(Advisor... advisors) {
List<Advisor> merged = new ArrayList<>(this.advisors.size() + 1);
merged.addAll(this.advisors);
merged.addAll(List.of(advisors));
AnnotationAwareOrderComparator.sort(merged);
return new AuthorizationProxyFactory(merged);
}

public Object proxy(Object target) {
if (target == null) {
return target;
}
if (ClassUtils.isSimpleValueType(target.getClass())) {
return target;
}
if (target instanceof Iterator<?> iterator) {
return proxyIterator(iterator);
}
if (target instanceof Collection<?> collection) {
return proxyCollection(collection);
}
if (target.getClass().isArray()) {
return proxyArray((Object[]) target);
}
if (target instanceof Map) {
return proxyMap((Map<?, ?>) target);
}
if (target instanceof Stream) {
return proxyStream((Stream<?>) target);
}
return proxySingle(target);
}

private <T> Iterator<T> proxyIterator(Iterator<T> iterator) {
return new Iterator<>() {
@Override
public boolean hasNext() {
return iterator.hasNext();
}

@Override
public T next() {
return proxySingle(iterator.next());
}
};
}

private <T> Collection<T> proxyCollection(Collection<T> collection) {
Collection<T> proxies = new ArrayList<>(collection.size());
for (T toProxy : collection) {
proxies.add(proxySingle(toProxy));
}
collection.clear();
collection.addAll(proxies);
return proxies;
}

private Object proxyArray(Object[] objects) {
List<Object> retain = new ArrayList<>(objects.length);
for (Object object : objects) {
retain.add(proxySingle(object));
}
Object[] proxies = (Object[]) Array.newInstance(objects.getClass().getComponentType(), retain.size());
for (int i = 0; i < retain.size(); i++) {
proxies[i] = retain.get(i);
}
return proxies;
}

private <K, V> Object proxyMap(Map<K, V> entries) {
Map<K, V> proxies = new LinkedHashMap<>(entries.size());
for (Map.Entry<K, V> entry : entries.entrySet()) {
proxies.put(entry.getKey(), proxySingle(entry.getValue()));
}
entries.clear();
entries.putAll(proxies);
return entries;
}

private Object proxyStream(Stream<?> stream) {
return stream.map(this::proxySingle).onClose(stream::close);
}

private <T> T proxySingle(T target) {
ProxyFactory factory = new ProxyFactory(target);
factory.addAdvisors(this.advisors);
factory.setProxyTargetClass(!Modifier.isFinal(target.getClass().getModifiers()));
return (T) factory.getProxy();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.authorization.object;

import org.junit.jupiter.api.Test;

import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.authorization.method.AuthorizationManagerBeforeMethodInterceptor;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

public class AuthorizationProxyFactoryTests {

private final Authentication user = TestAuthentication.authenticatedUser();

private final Authentication admin = TestAuthentication.authenticatedAdmin();

@Test
public void proxyWhenPreAuthorizeThenHonors() {
SecurityContextHolder.getContext().setAuthentication(this.user);
AuthorizationManagerBeforeMethodInterceptor preAuthorize = AuthorizationManagerBeforeMethodInterceptor
.preAuthorize();
AuthorizationProxyFactory factory = new AuthorizationProxyFactory(preAuthorize);
Flight flight = new Flight();
assertThat(flight.getAltitude()).isEqualTo(35000d);
Flight secured = (Flight) factory.proxy(flight);
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> secured.getAltitude());
SecurityContextHolder.clearContext();
}

@Test
public void proxyWhenPreAuthorizeOnInterfaceThenHonors() {
SecurityContextHolder.getContext().setAuthentication(this.user);
AuthorizationManagerBeforeMethodInterceptor preAuthorize = AuthorizationManagerBeforeMethodInterceptor
.preAuthorize();
AuthorizationProxyFactory factory = new AuthorizationProxyFactory(preAuthorize);
User user = new User("user", "First", "Last");
assertThat(user.getFirstName()).isEqualTo("First");
User secured = (User) factory.proxy(user);
assertThat(secured.getFirstName()).isEqualTo("First");
SecurityContextHolder.getContext().setAuthentication(authenticated("wrong"));
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> secured.getFirstName());
SecurityContextHolder.getContext().setAuthentication(this.admin);
assertThat(secured.getFirstName()).isEqualTo("First");
SecurityContextHolder.clearContext();
}

@Test
public void proxyWhenPreAuthorizeOnRecordThenHonors() {
SecurityContextHolder.getContext().setAuthentication(this.user);
AuthorizationManagerBeforeMethodInterceptor preAuthorize = AuthorizationManagerBeforeMethodInterceptor
.preAuthorize();
AuthorizationProxyFactory factory = new AuthorizationProxyFactory(preAuthorize);
HasSecret repo = new Repository("secret");
assertThat(repo.secret()).isEqualTo("secret");
HasSecret secured = (HasSecret) factory.proxy(repo);
assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> secured.secret());
SecurityContextHolder.getContext().setAuthentication(this.user);
assertThat(repo.secret()).isEqualTo("secret");
SecurityContextHolder.clearContext();
}

private Authentication authenticated(String user, String... authorities) {
return TestAuthentication.authenticated(TestAuthentication.withUsername(user).authorities(authorities).build());
}

static class Flight {

@PreAuthorize("hasRole('PILOT')")
Double getAltitude() {
return 35000d;
}

}

interface Identifiable {

String getId();

@PreAuthorize("authentication.name == this.id || hasRole('ADMIN')")
String getFirstName();

@PreAuthorize("authentication.name == this.id || hasRole('ADMIN')")
String getLastName();

}

static class User implements Identifiable {

private final String id;

private final String firstName;

private final String lastName;

User(String id, String firstName, String lastName) {
this.id = id;
this.firstName = firstName;
this.lastName = lastName;
}

@Override
public String getId() {
return this.id;
}

@Override
public String getFirstName() {
return this.firstName;
}

@Override
public String getLastName() {
return this.lastName;
}

}

interface HasSecret {

String secret();

}

record Repository(@PreAuthorize("hasRole('ADMIN')") String secret) implements HasSecret {
}

}

0 comments on commit 6d290c9

Please sign in to comment.