Newer
Older
simple-database / src / nl / astraeus / database / Persister.java
package nl.astraeus.database;

import nl.astraeus.database.jdbc.ConnectionPool;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Static facade for thread-bound JDBC transactions and simple object persistence.
 *
 * <p>Each thread holds at most one active {@link Transaction}, kept in a {@link ThreadLocal}.
 * Connections are obtained from {@link ConnectionPool}; the static initializer installs a
 * provider that connects to {@code jdbc:stat::jdbc:h2:mem:test} (an H2 in-memory database
 * addressed through the {@code nl.astraeus.jdbc.Driver} URL prefix) with auto-commit disabled,
 * so work is only made durable by {@link #commit()}.
 *
 * <p>Date: 11/13/13, Time: 9:28 PM
 */
public class Persister {

    /** The transaction bound to the calling thread; unset when no transaction is active. */
    private static final ThreadLocal<Transaction> transactions = new ThreadLocal<>();

    /**
     * Per-class persister cache shared by all threads. A {@link ConcurrentHashMap} makes the
     * unsynchronized fast-path read in {@link #getObjectPersister(Class)} safe; reading a plain
     * {@code HashMap} while another thread writes to it is a data race.
     */
    private static final Map<Class<?>, ObjectPersister> objectPersisters = new ConcurrentHashMap<>();

    static {
        ConnectionPool.get().setConnectionProvider(new ConnectionProvider() {
            @Override
            public Connection getConnection() {
                try {
                    // Load both driver classes before connecting; JDBC drivers conventionally
                    // register themselves with DriverManager when their class is loaded.
                    Class.forName("org.h2.Driver");
                    Class.forName("nl.astraeus.jdbc.Driver");

                    Connection connection = DriverManager.getConnection("jdbc:stat::jdbc:h2:mem:test", "sa", "");
                    // Transactions are demarcated explicitly via commit()/rollback().
                    connection.setAutoCommit(false);

                    return connection;
                } catch (ClassNotFoundException | SQLException e) {
                    throw new IllegalStateException(e);
                }
            }
        });
    }

    /**
     * Returns the connection of the transaction active on the calling thread.
     *
     * @return the current transaction's connection
     * @throws IllegalStateException if no transaction is active on this thread
     */
    protected static Connection getConnection() {
        Transaction transaction = transactions.get();
        if (transaction == null) {
            throw new IllegalStateException("No transaction active!");
        }

        return transaction.getConnection();
    }

    /**
     * Starts a new transaction on the calling thread using a connection from the pool.
     *
     * @throws IllegalStateException if a transaction is already active on this thread;
     *         silently replacing it would leak its connection and lose its uncommitted work
     */
    public static void begin() {
        if (transactions.get() != null) {
            throw new IllegalStateException("Transaction already active!");
        }

        transactions.set(new Transaction(ConnectionPool.get().getConnection()));
    }

    /**
     * Commits the calling thread's transaction, closes its connection and unbinds it.
     * Does nothing when no transaction is active.
     *
     * <p>If the commit fails, the pending work is rolled back before the connection is
     * closed, so the connection is not handed back with uncommitted changes.
     *
     * @throws IllegalStateException wrapping the {@link SQLException} if the commit fails;
     *         a failure of the follow-up rollback is attached as a suppressed exception
     */
    public static void commit() {
        Transaction transaction = transactions.get();
        if (transaction == null) {
            return;
        }

        Connection connection = transaction.getConnection();
        try {
            connection.commit();
        } catch (SQLException e) {
            try {
                connection.rollback();
            } catch (SQLException rollbackFailure) {
                e.addSuppressed(rollbackFailure);
            }
            throw new IllegalStateException(e);
        } finally {
            closeBestEffort(connection);
            transactions.remove();
        }
    }

    /**
     * Rolls back the calling thread's transaction, closes its connection and unbinds it.
     * Does nothing when no transaction is active.
     *
     * @throws IllegalStateException wrapping the {@link SQLException} if the rollback fails;
     *         the connection is closed and the transaction unbound regardless
     */
    public static void rollback() {
        Transaction transaction = transactions.get();
        if (transaction == null) {
            return;
        }

        Connection connection = transaction.getConnection();
        try {
            connection.rollback();
        } catch (SQLException e) {
            throw new IllegalStateException(e);
        } finally {
            closeBestEffort(connection);
            transactions.remove();
        }
    }

    /**
     * Runs {@code runnable} inside a new transaction: commits if it returns normally and
     * rolls back if it throws.
     *
     * @param runnable the work to perform within the transaction
     * @throws IllegalStateException if a transaction is already active on this thread, or if
     *         committing or rolling back fails
     */
    public static void execute(Runnable runnable) {
        // begin() stays outside the try: if it fails, this call owns no transaction and must
        // not roll back one started by someone else.
        begin();
        try {
            runnable.run();

            commit();
        } finally {
            // commit() unbinds the transaction, so it is only still active if run() threw.
            if (transactionActive()) {
                rollback();
            }
        }
    }

    /**
     * Inserts {@code obj} using the persister for its runtime class.
     *
     * @param obj the object to insert; must not be null
     * @throws NullPointerException if {@code obj} is null
     */
    public static void insert(Object obj) {
        getObjectPersister(Objects.requireNonNull(obj, "obj").getClass()).insert(obj);
    }

    /**
     * Updates {@code obj} using the persister for its runtime class.
     *
     * @param obj the object to update; must not be null
     * @throws NullPointerException if {@code obj} is null
     */
    public static void update(Object obj) {
        getObjectPersister(Objects.requireNonNull(obj, "obj").getClass()).update(obj);
    }

    /**
     * Deletes {@code obj} using the persister for its runtime class.
     *
     * @param obj the object to delete; must not be null
     * @throws NullPointerException if {@code obj} is null
     */
    public static void delete(Object obj) {
        getObjectPersister(Objects.requireNonNull(obj, "obj").getClass()).delete(obj);
    }

    /**
     * Looks up an instance of {@code cls} by id via the persister for that class.
     *
     * @param cls the class to look up; must not be null
     * @param id  the id to look up
     * @param <T> the requested type
     * @return whatever the class's persister returns for {@code id}, checked against {@code cls}
     * @throws NullPointerException if {@code cls} is null
     * @throws ClassCastException   if the persister returns an object that is not a {@code cls}
     */
    public static <T> T find(Class<T> cls, long id) {
        Objects.requireNonNull(cls, "cls");
        return cls.cast(getObjectPersister(cls).find(id));
    }

    /**
     * Returns the cached persister for {@code cls}, creating it on first use. The lock
     * ensures at most one {@link ObjectPersister} is constructed per class; the lock-free
     * first read is safe because the cache is a {@link ConcurrentHashMap}.
     */
    private static ObjectPersister getObjectPersister(Class<?> cls) {
        ObjectPersister result = objectPersisters.get(cls);

        if (result == null) {
            synchronized (Persister.class) {
                result = objectPersisters.get(cls);
                if (result == null) {
                    result = new ObjectPersister(cls);

                    objectPersisters.put(cls, result);
                }
            }
        }

        return result;
    }

    /**
     * Reports whether a transaction is active on the calling thread.
     *
     * @return {@code true} if {@link #begin()} was called and not yet ended by
     *         {@link #commit()} or {@link #rollback()}
     */
    public static boolean transactionActive() {
        return transactions.get() != null;
    }

    /**
     * Closes {@code connection}, printing but not propagating a failure, so that a close
     * error never masks the outcome of the preceding commit or rollback.
     */
    private static void closeBestEffort(Connection connection) {
        try {
            connection.close();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

}