Simplify session management.

This commit is contained in:
Alexis Lahouze 2015-11-24 21:31:24 +01:00
parent 5b43cd763e
commit 762efc64c8
8 changed files with 70 additions and 124 deletions

View File

@ -46,30 +46,6 @@ app.debug = config.debug
db = SQLAlchemy(app) db = SQLAlchemy(app)
@contextmanager
def session_scope():
from accountant import db
session = db.session
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()
def session_aware(f):
def wrapper(*args, **kwargs):
with session_scope() as session:
kwargs['session'] = session
return f(*args, **kwargs)
return wrapper
# Must be after db declaration because the blueprints may need it. # Must be after db declaration because the blueprints may need it.
from .api import api from .api import api
from .frontend import frontend, frontend_js, frontend_css from .frontend import frontend, frontend_js, frontend_css

View File

@ -32,8 +32,8 @@ class Account(db.Model):
self.authorized_overdraft = authorized_overdraft self.authorized_overdraft = authorized_overdraft
@classmethod @classmethod
def query(cls, session, begin=None, end=None): def query(cls, begin=None, end=None):
status_query = session.query( status_query = db.session.query(
Operation.account_id, Operation.account_id,
func.sum(Operation.value).label("future"), func.sum(Operation.value).label("future"),
func.sum( func.sum(
@ -55,7 +55,7 @@ class Account(db.Model):
).subquery() ).subquery()
if begin and end: if begin and end:
balance_query = session.query( balance_query = db.session.query(
Operation.account_id, Operation.account_id,
func.sum( func.sum(
case( case(
@ -77,7 +77,7 @@ class Account(db.Model):
Operation.account_id Operation.account_id
).subquery() ).subquery()
query = session.query( query = db.session.query(
cls.id, cls.id,
cls.name, cls.name,
cls.authorized_overdraft, cls.authorized_overdraft,
@ -93,7 +93,7 @@ class Account(db.Model):
balance_query, balance_query.c.account_id == cls.id balance_query, balance_query.c.account_id == cls.id
) )
else: else:
query = session.query( query = db.session.query(
cls.id, cls.id,
cls.name, cls.name,
cls.authorized_overdraft, cls.authorized_overdraft,

View File

@ -102,15 +102,15 @@ class Operation(db.Model):
self.canceled = canceled self.canceled = canceled
@classmethod @classmethod
def query(cls, session, begin=None, end=None): def query(cls, begin=None, end=None):
# We have to use a join because the sold is not computed from the # We have to use a join because the sold is not computed from the
# begining. # begining.
base_query = session.query( base_query = db.session.query(
cls.id, cls.id,
cls.sold cls.sold
).subquery() ).subquery()
query = session.query( query = db.session.query(
cls.id, cls.id,
cls.operation_date, cls.operation_date,
cls.label, cls.label,
@ -139,13 +139,13 @@ class Operation(db.Model):
return query return query
@classmethod @classmethod
def get_categories_for_range(cls, session, account, begin, end): def get_categories_for_range(cls, account, begin, end):
if isinstance(account, int) or isinstance(account, str): if isinstance(account, int) or isinstance(account, str):
account_id = account account_id = account
else: else:
account_id = account.id account_id = account.id
query = session.query( query = db.session.query(
cls.category, cls.category,
func.sum( func.sum(
case([(func.sign(cls.value) == -1, cls.value)], else_=0) case([(func.sign(cls.value) == -1, cls.value)], else_=0)
@ -168,7 +168,7 @@ class Operation(db.Model):
return query return query
@classmethod @classmethod
def get_ohlc_per_day_for_range(cls, session, account, begin=None, end=None): def get_ohlc_per_day_for_range(cls, account, begin=None, end=None):
if isinstance(account, int) or isinstance(account, str): if isinstance(account, int) or isinstance(account, str):
account_id = account account_id = account
else: else:
@ -182,7 +182,7 @@ class Operation(db.Model):
previous = sold - cls.value previous = sold - cls.value
subquery = session.query( subquery = db.session.query(
cls.operation_date, cls.operation_date,
sold.label("sold"), sold.label("sold"),
previous.label("previous") previous.label("previous")
@ -191,7 +191,7 @@ class Operation(db.Model):
cls.canceled == false() cls.canceled == false()
).subquery() ).subquery()
query = session.query( query = db.session.query(
subquery.c.operation_date, subquery.c.operation_date,
func.first_value(subquery.c.previous).over( func.first_value(subquery.c.previous).over(
partition_by=subquery.c.operation_date partition_by=subquery.c.operation_date

View File

@ -62,8 +62,8 @@ class ScheduledOperation(db.Model):
self.category = category self.category = category
@classmethod @classmethod
def query(cls, session): def query(cls):
return session.query( return db.session.query(
cls cls
).order_by( ).order_by(
desc(cls.day), desc(cls.day),
@ -71,9 +71,9 @@ class ScheduledOperation(db.Model):
cls.label, cls.label,
) )
def reschedule(self, session): def reschedule(self):
# 1) delete unconfirmed operations for this account. # 1) delete unconfirmed operations for this account.
session.query( db.session.query(
Operation Operation
).filter( ).filter(
Operation.scheduled_operation_id == self.id, Operation.scheduled_operation_id == self.id,
@ -106,7 +106,7 @@ class ScheduledOperation(db.Model):
# Search if a close operation exists. # Search if a close operation exists.
if not session.query( if not db.session.query(
Operation Operation
).filter( ).filter(
Operation.account_id == self.account_id, Operation.account_id == self.account_id,
@ -127,4 +127,4 @@ class ScheduledOperation(db.Model):
canceled=False canceled=False
) )
session.add(operation) db.session.add(operation)

View File

@ -55,7 +55,7 @@ class User(UserMixin, db.Model):
return serializer.dumps({'id': self.id}) return serializer.dumps({'id': self.id})
@classmethod @classmethod
def verify_auth_token(cls, session, token): def verify_auth_token(cls, token):
serializer = Serializer(app.config['SECRET_KEY']) serializer = Serializer(app.config['SECRET_KEY'])
try: try:
@ -65,5 +65,5 @@ class User(UserMixin, db.Model):
except BadSignature: except BadSignature:
return None return None
user = cls.query(session).get(data['id']) user = cls.query().get(data['id'])
return user return user

View File

@ -20,7 +20,7 @@ from flask.ext.restful import Resource, fields, reqparse, marshal_with_field
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from accountant import session_aware from accountant import db
from .. import api_api from .. import api_api
@ -54,17 +54,15 @@ date_parser.add_argument('end',
class AccountListResource(Resource): class AccountListResource(Resource):
@session_aware
@marshal_with_field(fields.List(Object(resource_fields))) @marshal_with_field(fields.List(Object(resource_fields)))
def get(self, session): def get(self):
""" """
Returns accounts with their balances. Returns accounts with their balances.
""" """
return Account.query(session).all(), 200 return Account.query().all(), 200
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def post(self, session): def post(self):
""" """
Create a new account. Create a new account.
""" """
@ -72,15 +70,13 @@ class AccountListResource(Resource):
account = Account(**kwargs) account = Account(**kwargs)
session.add(account) db.session.add(account)
# Flush session to have id in account. # Flush session to have id in account.
session.flush() db.session.flush()
# Return account with solds. # Return account with solds.
return Account.query( return Account.query().filter(
session
).filter(
Account.id == account.id Account.id == account.id
).one(), 201 ).one(), 201
@ -92,9 +88,8 @@ class AccountListResource(Resource):
class AccountResource(Resource): class AccountResource(Resource):
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def get(self, account_id, session): def get(self, account_id):
""" """
Get account. Get account.
""" """
@ -102,36 +97,34 @@ class AccountResource(Resource):
try: try:
return Account.query( return Account.query(
session, **kwargs **kwargs
).filter( ).filter(
Account.id == account_id Account.id == account_id
).one() ).one()
except NoResultFound: except NoResultFound:
return None, 404 return None, 404
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def delete(self, account_id, session): def delete(self, account_id):
# Need to get the object to update it. # Need to get the object to update it.
account = session.query(Account).get(account_id) account = db.session.query(Account).get(account_id)
if not account: if not account:
return None, 404 return None, 404
session.delete(account) db.session.delete(account)
return None, 204 return None, 204
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def post(self, account_id, session): def post(self, account_id):
kwargs = parser.parse_args() kwargs = parser.parse_args()
assert (id not in kwargs or kwargs.id is None assert (id not in kwargs or kwargs.id is None
or kwargs.id == account_id) or kwargs.id == account_id)
# Need to get the object to update it. # Need to get the object to update it.
account = session.query(Account).get(account_id) account = db.session.query(Account).get(account_id)
if not account: if not account:
return None, 404 return None, 404
@ -140,12 +133,10 @@ class AccountResource(Resource):
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(account, k, v) setattr(account, k, v)
session.merge(account) db.session.merge(account)
# Return account with solds. # Return account with solds.
return Account.query( return Account.query().filter(
session
).filter(
Account.id == account_id Account.id == account_id
).one() ).one()

View File

@ -18,7 +18,7 @@ import dateutil.parser
from flask.ext.restful import Resource, fields, reqparse, marshal_with_field from flask.ext.restful import Resource, fields, reqparse, marshal_with_field
from accountant import session_aware from accountant import db
from .. import api_api from .. import api_api
@ -62,54 +62,49 @@ range_parser.add_argument('end', type=lambda a: dateutil.parser.parse(a))
class OperationListResource(Resource): class OperationListResource(Resource):
@session_aware
@marshal_with_field(fields.List(Object(resource_fields))) @marshal_with_field(fields.List(Object(resource_fields)))
def get(self, session): def get(self):
kwargs = range_parser.parse_args() kwargs = range_parser.parse_args()
return Operation.query( return Operation.query(
session,
begin=kwargs['begin'], begin=kwargs['begin'],
end=kwargs['end'], end=kwargs['end'],
).filter( ).filter(
Operation.account_id == kwargs['account'] Operation.account_id == kwargs['account']
).all() ).all()
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def post(self, session): def post(self):
kwargs = parser.parse_args() kwargs = parser.parse_args()
operation = Operation(**kwargs) operation = Operation(**kwargs)
session.add(operation) db.session.add(operation)
return operation return operation
class OperationResource(Resource): class OperationResource(Resource):
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def get(self, operation_id, session): def get(self, operation_id):
""" """
Get operation. Get operation.
""" """
operation = session.query(Operation).get(operation_id) operation = db.session.query(Operation).get(operation_id)
if not operation: if not operation:
return None, 404 return None, 404
return operation return operation
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def post(self, operation_id, session): def post(self, operation_id):
kwargs = parser.parse_args() kwargs = parser.parse_args()
assert (id not in kwargs or kwargs.id is None assert (id not in kwargs or kwargs.id is None
or kwargs.id == operation_id) or kwargs.id == operation_id)
operation = session.query(Operation).get(operation_id) operation = db.session.query(Operation).get(operation_id)
if not operation: if not operation:
return None, 404 return None, 404
@ -118,19 +113,18 @@ class OperationResource(Resource):
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(operation, k, v) setattr(operation, k, v)
session.merge(operation) db.session.merge(operation)
return operation return operation
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def delete(self, operation_id, session): def delete(self, operation_id):
operation = session.query(Operation).get(operation_id) operation = db.session.query(Operation).get(operation_id)
if not operation: if not operation:
return None, 404 return None, 404
session.delete(operation) db.session.delete(operation)
return operation return operation
@ -143,12 +137,11 @@ category_resource_fields = {
class CategoriesResource(Resource): class CategoriesResource(Resource):
@session_aware
@marshal_with_field(fields.List(Object(category_resource_fields))) @marshal_with_field(fields.List(Object(category_resource_fields)))
def get(self, session): def get(self):
kwargs = range_parser.parse_args() kwargs = range_parser.parse_args()
return Operation.get_categories_for_range(session, **kwargs).all() return Operation.get_categories_for_range(**kwargs).all()
ohlc_resource_fields = { ohlc_resource_fields = {
@ -161,12 +154,11 @@ ohlc_resource_fields = {
class SoldsResource(Resource): class SoldsResource(Resource):
@session_aware
@marshal_with_field(fields.List(Object(ohlc_resource_fields))) @marshal_with_field(fields.List(Object(ohlc_resource_fields)))
def get(self, session): def get(self):
kwargs = range_parser.parse_args() kwargs = range_parser.parse_args()
return Operation.get_ohlc_per_day_for_range(session, **kwargs).all() return Operation.get_ohlc_per_day_for_range(**kwargs).all()
api_api.add_resource(OperationListResource, "/operations") api_api.add_resource(OperationListResource, "/operations")

View File

@ -21,7 +21,7 @@ from flask.ext.restful import Resource, fields, reqparse, marshal_with_field
from sqlalchemy import true from sqlalchemy import true
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from accountant import session_aware from accountant import db
from ..models.scheduled_operations import ScheduledOperation from ..models.scheduled_operations import ScheduledOperation
from ..models.operations import Operation from ..models.operations import Operation
@ -60,23 +60,19 @@ get_parser.add_argument('account', type=int)
class ScheduledOperationListResource(Resource): class ScheduledOperationListResource(Resource):
@session_aware
@marshal_with_field(fields.List(Object(resource_fields))) @marshal_with_field(fields.List(Object(resource_fields)))
def get(self, session): def get(self):
""" """
Get all scheduled operation for the account. Get all scheduled operation for the account.
""" """
kwargs = get_parser.parse_args() kwargs = get_parser.parse_args()
return ScheduledOperation.query( return ScheduledOperation.query().filter(
session
).filter(
ScheduledOperation.account_id == kwargs.account ScheduledOperation.account_id == kwargs.account
).all() ).all()
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def post(self, session): def post():
""" """
Add a new scheduled operation. Add a new scheduled operation.
""" """
@ -84,41 +80,35 @@ class ScheduledOperationListResource(Resource):
scheduled_operation = ScheduledOperation(**kwargs) scheduled_operation = ScheduledOperation(**kwargs)
session.add(scheduled_operation) db.session.add(scheduled_operation)
scheduled_operation.reschedule(session) scheduled_operation.reschedule()
session.flush() db.session.flush()
return scheduled_operation, 201 return scheduled_operation, 201
class ScheduledOperationResource(Resource): class ScheduledOperationResource(Resource):
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def get(self, scheduled_operation_id, session): def get(self, scheduled_operation_id):
""" """
Get scheduled operation. Get scheduled operation.
""" """
try: try:
return ScheduledOperation.query( return ScheduledOperation.query().filter(
session
).filter(
ScheduledOperation.id == scheduled_operation_id ScheduledOperation.id == scheduled_operation_id
).one() ).one()
except NoResultFound: except NoResultFound:
return None, 404 return None, 404
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def delete(self, scheduled_operation_id, session): def delete(self, scheduled_operation_id):
""" """
Delete a scheduled operation. Delete a scheduled operation.
""" """
try: try:
scheduled_operation = ScheduledOperation.query( scheduled_operation = ScheduledOperation.query().filter(
session
).filter(
ScheduledOperation.id == scheduled_operation_id ScheduledOperation.id == scheduled_operation_id
).one() ).one()
except NoResultFound: except NoResultFound:
@ -137,13 +127,12 @@ class ScheduledOperationResource(Resource):
# Delete unconfirmed operations # Delete unconfirmed operations
operations = scheduled_operation.operations.delete() operations = scheduled_operation.operations.delete()
session.delete(scheduled_operation) db.session.delete(scheduled_operation)
return None, 204 return None, 204
@session_aware
@marshal_with_field(Object(resource_fields)) @marshal_with_field(Object(resource_fields))
def post(self, scheduled_operation_id, session): def post(self, scheduled_operation_id):
""" """
Update a scheduled operation. Update a scheduled operation.
""" """
@ -153,9 +142,7 @@ class ScheduledOperationResource(Resource):
or kwargs.id == scheduled_operation_id) or kwargs.id == scheduled_operation_id)
try: try:
scheduled_operation = ScheduledOperation.query( scheduled_operation = ScheduledOperation.query().filter(
session
).filter(
ScheduledOperation.id == scheduled_operation_id ScheduledOperation.id == scheduled_operation_id
).one() ).one()
except NoResultFound: except NoResultFound:
@ -165,11 +152,11 @@ class ScheduledOperationResource(Resource):
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(scheduled_operation, k, v) setattr(scheduled_operation, k, v)
session.merge(scheduled_operation) db.session.merge(scheduled_operation)
scheduled_operation.reschedule(session) scheduled_operation.reschedule()
session.flush() db.session.flush()
return scheduled_operation return scheduled_operation