Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -729,13 +729,13 @@ public List<RestHandler> getRestHandlers(
}

@Override
public UnaryOperator<RestHandler> getRestHandlerWrapper(final ThreadContext threadContext, Set<RestHeaderDefinition> headersToCopy) {
public UnaryOperator<RestHandler> getRestHandlerWrapper(final ThreadContext threadContext, Set<RestHeaderDefinition> headersToCopy, Set<String> transientsToCopy) {

if (client || disabled || SSLConfig.isSslOnlyMode()) {
return (rh) -> rh;
}

return (rh) -> securityRestHandler.wrap(rh, adminDns, headersToCopy);
return (rh) -> securityRestHandler.wrap(rh, adminDns, headersToCopy, transientsToCopy);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.HTTPHelper;
import org.opensearch.security.user.User;
import org.opensearch.telemetry.tracing.TracerContextStorage;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.client.node.NodeClient;

Expand Down Expand Up @@ -127,11 +128,13 @@ public SecurityRestFilter(
class AuthczRestHandler extends DelegatingRestHandler {
private final AdminDNs adminDNs;
private final Set<RestHeaderDefinition> headersToCopy;
private final Set<String> transientsToCopy;

public AuthczRestHandler(RestHandler original, AdminDNs adminDNs, Set<RestHeaderDefinition> headersToCopy) {
public AuthczRestHandler(RestHandler original, AdminDNs adminDNs, Set<RestHeaderDefinition> headersToCopy, Set<String> transientsToCopy) {
super(original);
this.adminDNs = adminDNs;
this.headersToCopy = headersToCopy;
this.transientsToCopy = transientsToCopy;
}

@Override
Expand Down Expand Up @@ -159,6 +162,18 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
tmpHeaders.put(header.getName(), value);
}
}

Map<String, Object> trasients = null;
for (String transientValue : transientsToCopy) {
final Object value = threadContext.getTransient(transientValue);
if (value != null) {
if (trasients == null) {
trasients = new HashMap<>();
}
trasients.put(transientValue, value);
}
}

storedContext.restore();

if (tmpHeaders != null) {
Expand All @@ -167,6 +182,13 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
}
threadContext.putHeader(OPENSEARCH_SECURITY_REQUEST_HEADERS, String.join(",", tmpHeaders.keySet()));
}
if(trasients != null) {
for (Map.Entry<String, Object> transientVal : trasients.entrySet()) {
if (threadContext.getTransient(transientVal.getKey()) == null) {
threadContext.putTransient(transientVal.getKey(), transientVal.getValue());
}
}
}
});

NettyAttribute.popFrom(request, Netty4HttpRequestHeaderVerifier.UNCONSUMED_PARAMS).ifPresent(unconsumedParams -> {
Expand Down Expand Up @@ -263,8 +285,8 @@ RestRequest maybeFilterRestRequest(RestRequest request) throws IOException {
* See {@link AllowlistApiAction} for the implementation of this API.
* SuperAdmin is identified by credentials, which can be passed in the curl request.
*/
public RestHandler wrap(RestHandler original, AdminDNs adminDNs, Set<RestHeaderDefinition> headersToCopy) {
return new AuthczRestHandler(original, adminDNs, headersToCopy);
public RestHandler wrap(RestHandler original, AdminDNs adminDNs, Set<RestHeaderDefinition> headersToCopy, Set<String> transientsToCopy) {
return new AuthczRestHandler(original, adminDNs, headersToCopy, transientsToCopy);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

import java.nio.file.Path;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.List;

import org.junit.Before;
import org.junit.Test;
Expand All @@ -21,21 +24,20 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.*;
import org.opensearch.security.auditlog.AuditLog;
import org.opensearch.security.auth.BackendRegistry;
import org.opensearch.security.configuration.AdminDNs;
import org.opensearch.security.configuration.CompatConfig;
import org.opensearch.security.privileges.RestLayerPrivilegesEvaluator;
import org.opensearch.security.ssl.transport.PrincipalExtractor;
import org.opensearch.security.ssl.http.netty.Netty4HttpRequestHeaderVerifier;
import org.opensearch.telemetry.tracing.Span;
import org.opensearch.telemetry.tracing.TracerContextStorage;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.client.node.NodeClient;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
Expand All @@ -46,6 +48,7 @@ public class SecurityRestFilterUnitTests {

SecurityRestFilter sf;
RestHandler testRestHandler;
ThreadPool threadPool;

class TestRestHandler implements RestHandler {

Expand All @@ -60,6 +63,7 @@ public void setUp() throws NoSuchMethodException {
testRestHandler = new TestRestHandler();

ThreadPool tp = spy(new ThreadPool(Settings.builder().put("node.name", "mock").build()));
this.threadPool = tp;
doReturn(new ThreadContext(Settings.EMPTY)).when(tp).getThreadContext();

sf = new SecurityRestFilter(
Expand All @@ -81,7 +85,7 @@ public void setUp() throws NoSuchMethodException {
public void testSecurityRestFilterWrap() throws Exception {
AdminDNs adminDNs = mock(AdminDNs.class);

RestHandler wrappedRestHandler = sf.wrap(testRestHandler, adminDNs, new HashSet<>());
RestHandler wrappedRestHandler = sf.wrap(testRestHandler, adminDNs, new HashSet<>(), new HashSet<>());

assertTrue(wrappedRestHandler instanceof SecurityRestFilter.AuthczRestHandler);
assertFalse(wrappedRestHandler instanceof TestRestHandler);
Expand All @@ -93,7 +97,7 @@ public void testDoesCallDelegateOnSuccessfulAuthorization() throws Exception {
AdminDNs adminDNs = mock(AdminDNs.class);

RestHandler testRestHandlerSpy = spy(testRestHandler);
RestHandler wrappedRestHandler = filterSpy.wrap(testRestHandlerSpy, adminDNs, new HashSet<>());
RestHandler wrappedRestHandler = filterSpy.wrap(testRestHandlerSpy, adminDNs, new HashSet<>(), new HashSet<>());

doReturn(false).when(filterSpy).userIsSuperAdmin(any(), any());

Expand All @@ -104,4 +108,69 @@ public void testDoesCallDelegateOnSuccessfulAuthorization() throws Exception {

// unit tests for restPathMatches are in RestPathMatchesTests.java

//Test that current_span transient is preserved after context restoration.
@Test
public void testCurrentSpanTransientPreservedAfterRestore() throws Exception {
ThreadContext threadContext = threadPool.getThreadContext();
// Handler verifies span is present
RestHandler testHandler = (request, channel, client) -> {
assertNotNull("CURRENT_SPAN should be preserved",
threadContext.getTransient(TracerContextStorage.CURRENT_SPAN));
};

Set<String> transientsToCopy = new HashSet<>(List.of(TracerContextStorage.CURRENT_SPAN));
RestHandler wrappedRestHandler = sf.wrap(testHandler, mock(AdminDNs.class), new HashSet<>(), transientsToCopy);
RestRequest request = addRelevantMocksAndGetRequest(threadContext);

threadContext.putTransient(TracerContextStorage.CURRENT_SPAN, mock(Span.class));
wrappedRestHandler.handleRequest(request, mock(RestChannel.class), mock(NodeClient.class));

assertNotNull("current_span should be preserved after handleRequest completes",
threadContext.getTransient(TracerContextStorage.CURRENT_SPAN));

}

// Current span is present in context ,not in transientsToCopy, hence we should NOT find it later
@Test
public void testCurrentSpanTransientNotPreservedAfterRestore() throws Exception {
ThreadContext threadContext = threadPool.getThreadContext();

// Handler verifies span is absent
RestHandler testHandler = (request, channel, client) -> {
assertNull("CURRENT_SPAN should NOT be preserved",
threadContext.getTransient(TracerContextStorage.CURRENT_SPAN));
};

Set<String> transientsToCopy = new HashSet<>();
RestHandler wrappedRestHandler = sf.wrap(testHandler, mock(AdminDNs.class), new HashSet<>(), transientsToCopy);
RestRequest request = addRelevantMocksAndGetRequest(threadContext);

threadContext.putTransient(TracerContextStorage.CURRENT_SPAN, mock(Span.class));

wrappedRestHandler.handleRequest(request, mock(RestChannel.class), mock(NodeClient.class));
assertNull("current_span should NOT be preserved after handleRequest completes as its not present in transientsToCopy",
threadContext.getTransient(TracerContextStorage.CURRENT_SPAN));
}


@SuppressWarnings("unchecked")
private RestRequest addRelevantMocksAndGetRequest(ThreadContext threadContext ) {
// Mock Netty attributes
RestRequest request = mock(RestRequest.class);
org.opensearch.http.HttpChannel httpChannel = mock(org.opensearch.http.HttpChannel.class);
io.netty.channel.Channel nettyChannel = mock(io.netty.channel.Channel.class);

doReturn(httpChannel).doReturn(httpChannel).doReturn(null).when(request).getHttpChannel();
doReturn(Optional.of(nettyChannel)).when(httpChannel).get("channel", io.netty.channel.Channel.class);

io.netty.util.Attribute<ThreadContext.StoredContext> contextAttr = mock(io.netty.util.Attribute.class);
io.netty.util.Attribute<SecurityResponse> earlyResponseAttr = mock(io.netty.util.Attribute.class);
doReturn(contextAttr).when(nettyChannel).attr(Netty4HttpRequestHeaderVerifier.CONTEXT_TO_RESTORE);
doReturn(earlyResponseAttr).when(nettyChannel).attr(Netty4HttpRequestHeaderVerifier.EARLY_RESPONSE);
doReturn(null).when(earlyResponseAttr).getAndSet(null);

ThreadContext.StoredContext storedContext = threadContext.newStoredContext(true);
doReturn(storedContext).when(contextAttr).getAndSet(null);
return request;
}
}