Skip to content

Commit

Permalink
Handling unbinding when nested txns are running on different DB [#161…
Browse files Browse the repository at this point in the history
…113716]
  • Loading branch information
Siddharth Srivastava committed Nov 21, 2018
1 parent 0246817 commit 0f8fc58
Show file tree
Hide file tree
Showing 12 changed files with 246 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@

package in.cleartax.dropwizard.sharding.hibernate;

import com.google.common.base.Preconditions;
import org.hibernate.context.spi.CurrentTenantIdentifierResolver;

import static com.google.common.base.Preconditions.checkNotNull;
import java.util.Stack;

public class DelegatingTenantResolver implements CurrentTenantIdentifierResolver {

private static DelegatingTenantResolver instance;

private ThreadLocal<CurrentTenantIdentifierResolver> delegate = new ThreadLocal<>();
private ThreadLocal<Stack<CurrentTenantIdentifierResolver>> delegate =
ThreadLocal.withInitial(Stack::new);

private DelegatingTenantResolver() {

Expand All @@ -43,18 +45,24 @@ public static DelegatingTenantResolver getInstance() {
}

public void setDelegate(CurrentTenantIdentifierResolver resolver) {
delegate.set(resolver);
if (resolver != null) {
delegate.get().add(resolver);
} else {
delegate.get().pop();
}
}

@Override
public String resolveCurrentTenantIdentifier() {
checkNotNull(delegate.get(), "Did you forget to set tenantId");

return delegate.get().resolveCurrentTenantIdentifier();
Preconditions.checkArgument(!delegate.get().isEmpty(), "Did you forget to set tenantId");
//noinspection ConstantConditions
return delegate.get().peek().resolveCurrentTenantIdentifier();
}

@Override
public boolean validateExistingCurrentSessions() {
return delegate.get().validateExistingCurrentSessions();
Preconditions.checkArgument(!delegate.get().isEmpty(), "Did you forget to set tenantId");
//noinspection ConstantConditions
return delegate.get().peek().validateExistingCurrentSessions();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright 2018 Saurabh Agrawal (Cleartax)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package in.cleartax.dropwizard.sharding.hibernate;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Value;
import org.hibernate.SessionFactory;

@AllArgsConstructor
@Value
@Builder
public class MultiTenantSessionSource {
private SessionFactory sessionFactory;
private MultiTenantDataSourceFactory dataSourceFactory;
private MultiTenantUnitOfWorkAwareProxyFactory unitOfWorkAwareProxyFactory;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package in.cleartax.dropwizard.sharding.hibernate;

import io.dropwizard.hibernate.HibernateBundle;
import io.dropwizard.hibernate.UnitOfWork;
import org.hibernate.Session;
import org.hibernate.SessionFactory;
import org.hibernate.Transaction;
import org.hibernate.context.internal.ManagedSessionContext;

import javax.annotation.Nullable;
import java.util.Map;
import java.util.Stack;

import static com.google.common.base.Preconditions.checkNotNull;
import static java.util.Objects.requireNonNull;

/**
* Created on 14/11/18
*/
public class MultiTenantUnitOfWorkAspect {
private static ThreadLocal<Stack<Session>> CONTEXT_OPEN_SESSIONS = ThreadLocal.withInitial(Stack::new);

private final Map<String, SessionFactory> sessionFactories;
// Context variables
@Nullable
private UnitOfWork unitOfWork;
@Nullable
private Session session;
@Nullable
private SessionFactory sessionFactory;

public MultiTenantUnitOfWorkAspect(Map<String, SessionFactory> sessionFactories) {
this.sessionFactories = sessionFactories;
}

public void beforeStart(@Nullable UnitOfWork unitOfWork) {
if (unitOfWork == null) {
return;
}
this.unitOfWork = unitOfWork;

sessionFactory = sessionFactories.get(unitOfWork.value());
if (sessionFactory == null) {
// If the user didn't specify the name of a session factory,
// and we have only one registered, we can assume that it's the right one.
if (unitOfWork.value().equals(HibernateBundle.DEFAULT_NAME) && sessionFactories.size() == 1) {
sessionFactory = sessionFactories.values().iterator().next();
} else {
throw new IllegalArgumentException("Unregistered Hibernate bundle: '" + unitOfWork.value() + "'");
}
}
session = sessionFactory.openSession();
assert session != null;
try {
configureSession();
bind(session);
beginTransaction(unitOfWork, session);
} catch (Throwable th) {
session.close();
session = null;
unbind(sessionFactory);
throw th;
}
}

public void afterEnd() {
if (unitOfWork == null || session == null) {
return;
}

try {
commitTransaction(unitOfWork, session);
} catch (Exception e) {
rollbackTransaction(unitOfWork, session);
throw e;
}
// We should not close the session to let the lazy loading work during serializing a response to the client.
// If the response successfully serialized, then the session will be closed by the `onFinish` method
}

public void onError() {
if (unitOfWork == null || session == null) {
return;
}

try {
rollbackTransaction(unitOfWork, session);
} finally {
onFinish();
}
}

public void onFinish() {
try {
if (session != null) {
session.close();
}
} finally {
session = null;
unbind(sessionFactory);
}
}

protected void configureSession() {
checkNotNull(unitOfWork);
checkNotNull(session);
session.setDefaultReadOnly(unitOfWork.readOnly());
session.setCacheMode(unitOfWork.cacheMode());
session.setHibernateFlushMode(unitOfWork.flushMode());
}

private void beginTransaction(UnitOfWork unitOfWork, Session session) {
if (!unitOfWork.transactional()) {
return;
}
session.beginTransaction();
}

private void rollbackTransaction(UnitOfWork unitOfWork, Session session) {
if (!unitOfWork.transactional()) {
return;
}
final Transaction txn = session.getTransaction();
if (txn != null && txn.getStatus().canRollback()) {
txn.rollback();
}
}

private void commitTransaction(UnitOfWork unitOfWork, Session session) {
if (!unitOfWork.transactional()) {
return;
}
final Transaction txn = session.getTransaction();
if (txn != null && txn.getStatus().canRollback()) {
txn.commit();
}
}

protected Session getSession() {
return requireNonNull(session);
}

protected SessionFactory getSessionFactory() {
return requireNonNull(sessionFactory);
}

private void bind(Session session) {
CONTEXT_OPEN_SESSIONS.get().push(session);
ManagedSessionContext.bind(session);
}

private void unbind(SessionFactory sessionFactory) {
ManagedSessionContext.unbind(sessionFactory);
// This defensive check is needed as in case of exception onFinish gets called multiple times.
if (!CONTEXT_OPEN_SESSIONS.get().isEmpty()) {
CONTEXT_OPEN_SESSIONS.get().pop();
}
if (!CONTEXT_OPEN_SESSIONS.get().isEmpty()) {
ManagedSessionContext.bind(CONTEXT_OPEN_SESSIONS.get().peek());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package in.cleartax.dropwizard.sharding.hibernate;

import com.google.common.collect.ImmutableMap;
import io.dropwizard.hibernate.UnitOfWorkAspect;
import org.hibernate.SessionFactory;

public class MultiTenantUnitOfWorkAwareProxyFactory {
Expand All @@ -39,15 +38,15 @@ public MultiTenantUnitOfWorkAwareProxyFactory(MultiTenantHibernateBundle<?>... b
/**
* @return a new aspect
*/
public UnitOfWorkAspect newAspect() {
return new UnitOfWorkAspect(sessionFactories);
public MultiTenantUnitOfWorkAspect newAspect() {
return new MultiTenantUnitOfWorkAspect(sessionFactories);
}

/**
* @param sessionFactories
* @return a new aspect
*/
public UnitOfWorkAspect newAspect(ImmutableMap<String, SessionFactory> sessionFactories) {
return new UnitOfWorkAspect(sessionFactories);
public MultiTenantUnitOfWorkAspect newAspect(ImmutableMap<String, SessionFactory> sessionFactories) {
return new MultiTenantUnitOfWorkAspect(sessionFactories);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import in.cleartax.dropwizard.sharding.hibernate.ConstTenantIdentifierResolver;
import in.cleartax.dropwizard.sharding.hibernate.DelegatingTenantResolver;
import in.cleartax.dropwizard.sharding.hibernate.MultiTenantUnitOfWorkAspect;
import in.cleartax.dropwizard.sharding.hibernate.MultiTenantUnitOfWorkAwareProxyFactory;
import io.dropwizard.hibernate.UnitOfWork;
import io.dropwizard.hibernate.UnitOfWorkAspect;
import lombok.AllArgsConstructor;
import org.hibernate.SessionFactory;
import org.hibernate.context.internal.ManagedSessionContext;
Expand All @@ -39,7 +39,7 @@ public T start(boolean reUseSession, UnitOfWork unitOfWork) throws Throwable {
return run();
}
DelegatingTenantResolver.getInstance().setDelegate(tenantIdentifierResolver);
UnitOfWorkAspect aspect = proxyFactory.newAspect();
MultiTenantUnitOfWorkAspect aspect = proxyFactory.newAspect();
Exception ex = null;
T result = null;
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@
import com.google.inject.Inject;
import com.google.inject.matcher.Matchers;
import in.cleartax.dropwizard.sharding.hibernate.ConstTenantIdentifierResolver;
import in.cleartax.dropwizard.sharding.hibernate.MultiTenantDataSourceFactory;
import in.cleartax.dropwizard.sharding.hibernate.MultiTenantUnitOfWorkAwareProxyFactory;
import in.cleartax.dropwizard.sharding.hibernate.DelegatingTenantResolver;
import in.cleartax.dropwizard.sharding.hibernate.MultiTenantSessionSource;
import in.cleartax.dropwizard.sharding.providers.ShardKeyProvider;
import in.cleartax.dropwizard.sharding.resolvers.bucket.BucketResolver;
import in.cleartax.dropwizard.sharding.resolvers.shard.ShardResolver;
import io.dropwizard.hibernate.UnitOfWork;
import lombok.extern.slf4j.Slf4j;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.hibernate.SessionFactory;

import javax.inject.Named;
import java.util.Objects;

@Slf4j
public class UnitOfWorkModule extends AbstractModule {

@Override
Expand All @@ -45,31 +45,28 @@ protected void configure() {

private static class UnitOfWorkInterceptor implements MethodInterceptor {

@Inject
MultiTenantUnitOfWorkAwareProxyFactory proxyFactory;
@Inject
@Named("session")
SessionFactory sessionFactory;
@Inject
BucketResolver bucketResolver;
@Inject
ShardResolver shardResolver;
@Inject
ShardKeyProvider shardKeyProvider;
@Inject
@Named("multiTenantConfiguration")
MultiTenantDataSourceFactory multiTenantDataSourceFactory;
MultiTenantSessionSource multiTenantSessionSource;

private String getTenantIdentifier(MethodInvocation mi) {
boolean useDefaultShard = mi.getMethod().isAnnotationPresent(DefaultTenant.class);
String tenantId;
if (!useDefaultShard && multiTenantDataSourceFactory.isAllowMultipleTenants()) {
if (!useDefaultShard && multiTenantSessionSource.getDataSourceFactory().isAllowMultipleTenants()) {
String shardKey = shardKeyProvider.getKey();
Objects.requireNonNull(shardKey, "No tenant-identifier set for this session");
String bucketId = bucketResolver.resolve(shardKey);
tenantId = shardResolver.resolve(bucketId);
if (shardKey != null) {
String bucketId = bucketResolver.resolve(shardKey);
tenantId = shardResolver.resolve(bucketId);
} else {
tenantId = DelegatingTenantResolver.getInstance().resolveCurrentTenantIdentifier();
}
} else {
tenantId = multiTenantDataSourceFactory.getDefaultTenant();
tenantId = multiTenantSessionSource.getDataSourceFactory().getDefaultTenant();
}
return tenantId;
}
Expand All @@ -79,7 +76,8 @@ public Object invoke(MethodInvocation mi) throws Throwable {
String tenantId = getTenantIdentifier(mi);
Objects.requireNonNull(tenantId, "No tenant-identifier found for this session");

TransactionRunner runner = new TransactionRunner(proxyFactory, sessionFactory,
TransactionRunner runner = new TransactionRunner(multiTenantSessionSource.getUnitOfWorkAwareProxyFactory(),
multiTenantSessionSource.getSessionFactory(),
new ConstTenantIdentifierResolver(tenantId)) {
@Override
public Object run() throws Throwable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import in.cleartax.dropwizard.sharding.resolvers.shard.ShardResolver;
import in.cleartax.dropwizard.sharding.transactions.DefaultTenant;
import in.cleartax.dropwizard.sharding.transactions.ReuseSession;
import in.cleartax.dropwizard.sharding.utils.dao.BucketToShardMappingDAO;
import io.dropwizard.hibernate.UnitOfWork;
import lombok.RequiredArgsConstructor;
Expand All @@ -34,6 +35,7 @@ public class DbBasedShardResolver implements ShardResolver {
@Override
@UnitOfWork
@DefaultTenant
@ReuseSession
public String resolve(String bucketId) {
Optional<String> shardId = dao.getShardId(bucketId);
if (!shardId.isPresent()) {
Expand Down
Loading

0 comments on commit 0f8fc58

Please sign in to comment.