Simplify session management.

This commit is contained in:
Alexis Lahouze 2015-11-24 21:31:24 +01:00
parent 5b43cd763e
commit 762efc64c8
8 changed files with 70 additions and 124 deletions

View File

@ -46,30 +46,6 @@ app.debug = config.debug
db = SQLAlchemy(app)
@contextmanager
def session_scope():
from accountant import db
session = db.session
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()
def session_aware(f):
def wrapper(*args, **kwargs):
with session_scope() as session:
kwargs['session'] = session
return f(*args, **kwargs)
return wrapper
# Must be after db declaration because the blueprints may need it.
from .api import api
from .frontend import frontend, frontend_js, frontend_css

View File

@ -32,8 +32,8 @@ class Account(db.Model):
self.authorized_overdraft = authorized_overdraft
@classmethod
def query(cls, session, begin=None, end=None):
status_query = session.query(
def query(cls, begin=None, end=None):
status_query = db.session.query(
Operation.account_id,
func.sum(Operation.value).label("future"),
func.sum(
@ -55,7 +55,7 @@ class Account(db.Model):
).subquery()
if begin and end:
balance_query = session.query(
balance_query = db.session.query(
Operation.account_id,
func.sum(
case(
@ -77,7 +77,7 @@ class Account(db.Model):
Operation.account_id
).subquery()
query = session.query(
query = db.session.query(
cls.id,
cls.name,
cls.authorized_overdraft,
@ -93,7 +93,7 @@ class Account(db.Model):
balance_query, balance_query.c.account_id == cls.id
)
else:
query = session.query(
query = db.session.query(
cls.id,
cls.name,
cls.authorized_overdraft,

View File

@ -102,15 +102,15 @@ class Operation(db.Model):
self.canceled = canceled
@classmethod
def query(cls, session, begin=None, end=None):
def query(cls, begin=None, end=None):
# We have to use a join because the sold is not computed from the
# begining.
base_query = session.query(
base_query = db.session.query(
cls.id,
cls.sold
).subquery()
query = session.query(
query = db.session.query(
cls.id,
cls.operation_date,
cls.label,
@ -139,13 +139,13 @@ class Operation(db.Model):
return query
@classmethod
def get_categories_for_range(cls, session, account, begin, end):
def get_categories_for_range(cls, account, begin, end):
if isinstance(account, int) or isinstance(account, str):
account_id = account
else:
account_id = account.id
query = session.query(
query = db.session.query(
cls.category,
func.sum(
case([(func.sign(cls.value) == -1, cls.value)], else_=0)
@ -168,7 +168,7 @@ class Operation(db.Model):
return query
@classmethod
def get_ohlc_per_day_for_range(cls, session, account, begin=None, end=None):
def get_ohlc_per_day_for_range(cls, account, begin=None, end=None):
if isinstance(account, int) or isinstance(account, str):
account_id = account
else:
@ -182,7 +182,7 @@ class Operation(db.Model):
previous = sold - cls.value
subquery = session.query(
subquery = db.session.query(
cls.operation_date,
sold.label("sold"),
previous.label("previous")
@ -191,7 +191,7 @@ class Operation(db.Model):
cls.canceled == false()
).subquery()
query = session.query(
query = db.session.query(
subquery.c.operation_date,
func.first_value(subquery.c.previous).over(
partition_by=subquery.c.operation_date

View File

@ -62,8 +62,8 @@ class ScheduledOperation(db.Model):
self.category = category
@classmethod
def query(cls, session):
return session.query(
def query(cls):
return db.session.query(
cls
).order_by(
desc(cls.day),
@ -71,9 +71,9 @@ class ScheduledOperation(db.Model):
cls.label,
)
def reschedule(self, session):
def reschedule(self):
# 1) delete unconfirmed operations for this account.
session.query(
db.session.query(
Operation
).filter(
Operation.scheduled_operation_id == self.id,
@ -106,7 +106,7 @@ class ScheduledOperation(db.Model):
# Search if a close operation exists.
if not session.query(
if not db.session.query(
Operation
).filter(
Operation.account_id == self.account_id,
@ -127,4 +127,4 @@ class ScheduledOperation(db.Model):
canceled=False
)
session.add(operation)
db.session.add(operation)

View File

@ -55,7 +55,7 @@ class User(UserMixin, db.Model):
return serializer.dumps({'id': self.id})
@classmethod
def verify_auth_token(cls, session, token):
def verify_auth_token(cls, token):
serializer = Serializer(app.config['SECRET_KEY'])
try:
@ -65,5 +65,5 @@ class User(UserMixin, db.Model):
except BadSignature:
return None
user = cls.query(session).get(data['id'])
user = cls.query().get(data['id'])
return user

View File

@ -20,7 +20,7 @@ from flask.ext.restful import Resource, fields, reqparse, marshal_with_field
from sqlalchemy.orm.exc import NoResultFound
from accountant import session_aware
from accountant import db
from .. import api_api
@ -54,17 +54,15 @@ date_parser.add_argument('end',
class AccountListResource(Resource):
@session_aware
@marshal_with_field(fields.List(Object(resource_fields)))
def get(self, session):
def get(self):
"""
Returns accounts with their balances.
"""
return Account.query(session).all(), 200
return Account.query().all(), 200
@session_aware
@marshal_with_field(Object(resource_fields))
def post(self, session):
def post(self):
"""
Create a new account.
"""
@ -72,15 +70,13 @@ class AccountListResource(Resource):
account = Account(**kwargs)
session.add(account)
db.session.add(account)
# Flush session to have id in account.
session.flush()
db.session.flush()
# Return account with solds.
return Account.query(
session
).filter(
return Account.query().filter(
Account.id == account.id
).one(), 201
@ -92,9 +88,8 @@ class AccountListResource(Resource):
class AccountResource(Resource):
@session_aware
@marshal_with_field(Object(resource_fields))
def get(self, account_id, session):
def get(self, account_id):
"""
Get account.
"""
@ -102,36 +97,34 @@ class AccountResource(Resource):
try:
return Account.query(
session, **kwargs
**kwargs
).filter(
Account.id == account_id
).one()
except NoResultFound:
return None, 404
@session_aware
@marshal_with_field(Object(resource_fields))
def delete(self, account_id, session):
def delete(self, account_id):
# Need to get the object to update it.
account = session.query(Account).get(account_id)
account = db.session.query(Account).get(account_id)
if not account:
return None, 404
session.delete(account)
db.session.delete(account)
return None, 204
@session_aware
@marshal_with_field(Object(resource_fields))
def post(self, account_id, session):
def post(self, account_id):
kwargs = parser.parse_args()
assert (id not in kwargs or kwargs.id is None
or kwargs.id == account_id)
# Need to get the object to update it.
account = session.query(Account).get(account_id)
account = db.session.query(Account).get(account_id)
if not account:
return None, 404
@ -140,12 +133,10 @@ class AccountResource(Resource):
for k, v in kwargs.items():
setattr(account, k, v)
session.merge(account)
db.session.merge(account)
# Return account with solds.
return Account.query(
session
).filter(
return Account.query().filter(
Account.id == account_id
).one()

View File

@ -18,7 +18,7 @@ import dateutil.parser
from flask.ext.restful import Resource, fields, reqparse, marshal_with_field
from accountant import session_aware
from accountant import db
from .. import api_api
@ -62,54 +62,49 @@ range_parser.add_argument('end', type=lambda a: dateutil.parser.parse(a))
class OperationListResource(Resource):
@session_aware
@marshal_with_field(fields.List(Object(resource_fields)))
def get(self, session):
def get(self):
kwargs = range_parser.parse_args()
return Operation.query(
session,
begin=kwargs['begin'],
end=kwargs['end'],
).filter(
Operation.account_id == kwargs['account']
).all()
@session_aware
@marshal_with_field(Object(resource_fields))
def post(self, session):
def post(self):
kwargs = parser.parse_args()
operation = Operation(**kwargs)
session.add(operation)
db.session.add(operation)
return operation
class OperationResource(Resource):
@session_aware
@marshal_with_field(Object(resource_fields))
def get(self, operation_id, session):
def get(self, operation_id):
"""
Get operation.
"""
operation = session.query(Operation).get(operation_id)
operation = db.session.query(Operation).get(operation_id)
if not operation:
return None, 404
return operation
@session_aware
@marshal_with_field(Object(resource_fields))
def post(self, operation_id, session):
def post(self, operation_id):
kwargs = parser.parse_args()
assert (id not in kwargs or kwargs.id is None
or kwargs.id == operation_id)
operation = session.query(Operation).get(operation_id)
operation = db.session.query(Operation).get(operation_id)
if not operation:
return None, 404
@ -118,19 +113,18 @@ class OperationResource(Resource):
for k, v in kwargs.items():
setattr(operation, k, v)
session.merge(operation)
db.session.merge(operation)
return operation
@session_aware
@marshal_with_field(Object(resource_fields))
def delete(self, operation_id, session):
operation = session.query(Operation).get(operation_id)
def delete(self, operation_id):
operation = db.session.query(Operation).get(operation_id)
if not operation:
return None, 404
session.delete(operation)
db.session.delete(operation)
return operation
@ -143,12 +137,11 @@ category_resource_fields = {
class CategoriesResource(Resource):
@session_aware
@marshal_with_field(fields.List(Object(category_resource_fields)))
def get(self, session):
def get(self):
kwargs = range_parser.parse_args()
return Operation.get_categories_for_range(session, **kwargs).all()
return Operation.get_categories_for_range(**kwargs).all()
ohlc_resource_fields = {
@ -161,12 +154,11 @@ ohlc_resource_fields = {
class SoldsResource(Resource):
@session_aware
@marshal_with_field(fields.List(Object(ohlc_resource_fields)))
def get(self, session):
def get(self):
kwargs = range_parser.parse_args()
return Operation.get_ohlc_per_day_for_range(session, **kwargs).all()
return Operation.get_ohlc_per_day_for_range(**kwargs).all()
api_api.add_resource(OperationListResource, "/operations")

View File

@ -21,7 +21,7 @@ from flask.ext.restful import Resource, fields, reqparse, marshal_with_field
from sqlalchemy import true
from sqlalchemy.orm.exc import NoResultFound
from accountant import session_aware
from accountant import db
from ..models.scheduled_operations import ScheduledOperation
from ..models.operations import Operation
@ -60,23 +60,19 @@ get_parser.add_argument('account', type=int)
class ScheduledOperationListResource(Resource):
@session_aware
@marshal_with_field(fields.List(Object(resource_fields)))
def get(self, session):
def get(self):
"""
Get all scheduled operation for the account.
"""
kwargs = get_parser.parse_args()
return ScheduledOperation.query(
session
).filter(
return ScheduledOperation.query().filter(
ScheduledOperation.account_id == kwargs.account
).all()
@session_aware
@marshal_with_field(Object(resource_fields))
def post(self, session):
def post():
"""
Add a new scheduled operation.
"""
@ -84,41 +80,35 @@ class ScheduledOperationListResource(Resource):
scheduled_operation = ScheduledOperation(**kwargs)
session.add(scheduled_operation)
db.session.add(scheduled_operation)
scheduled_operation.reschedule(session)
scheduled_operation.reschedule()
session.flush()
db.session.flush()
return scheduled_operation, 201
class ScheduledOperationResource(Resource):
@session_aware
@marshal_with_field(Object(resource_fields))
def get(self, scheduled_operation_id, session):
def get(self, scheduled_operation_id):
"""
Get scheduled operation.
"""
try:
return ScheduledOperation.query(
session
).filter(
return ScheduledOperation.query().filter(
ScheduledOperation.id == scheduled_operation_id
).one()
except NoResultFound:
return None, 404
@session_aware
@marshal_with_field(Object(resource_fields))
def delete(self, scheduled_operation_id, session):
def delete(self, scheduled_operation_id):
"""
Delete a scheduled operation.
"""
try:
scheduled_operation = ScheduledOperation.query(
session
).filter(
scheduled_operation = ScheduledOperation.query().filter(
ScheduledOperation.id == scheduled_operation_id
).one()
except NoResultFound:
@ -137,13 +127,12 @@ class ScheduledOperationResource(Resource):
# Delete unconfirmed operations
operations = scheduled_operation.operations.delete()
session.delete(scheduled_operation)
db.session.delete(scheduled_operation)
return None, 204
@session_aware
@marshal_with_field(Object(resource_fields))
def post(self, scheduled_operation_id, session):
def post(self, scheduled_operation_id):
"""
Update a scheduled operation.
"""
@ -153,9 +142,7 @@ class ScheduledOperationResource(Resource):
or kwargs.id == scheduled_operation_id)
try:
scheduled_operation = ScheduledOperation.query(
session
).filter(
scheduled_operation = ScheduledOperation.query().filter(
ScheduledOperation.id == scheduled_operation_id
).one()
except NoResultFound:
@ -165,11 +152,11 @@ class ScheduledOperationResource(Resource):
for k, v in kwargs.items():
setattr(scheduled_operation, k, v)
session.merge(scheduled_operation)
db.session.merge(scheduled_operation)
scheduled_operation.reschedule(session)
scheduled_operation.reschedule()
session.flush()
db.session.flush()
return scheduled_operation