From 762efc64c8ebbcaa76996fd4af146d94a70f0577 Mon Sep 17 00:00:00 2001 From: Alexis Lahouze Date: Tue, 24 Nov 2015 21:31:24 +0100 Subject: [PATCH] Simplify session management. --- accountant/__init__.py | 24 ---------- accountant/api/models/accounts.py | 10 ++-- accountant/api/models/operations.py | 16 +++---- accountant/api/models/scheduled_operations.py | 12 ++--- accountant/api/models/users.py | 4 +- accountant/api/views/accounts.py | 41 +++++++--------- accountant/api/views/operations.py | 40 +++++++--------- accountant/api/views/scheduled_operations.py | 47 +++++++------------ 8 files changed, 70 insertions(+), 124 deletions(-) diff --git a/accountant/__init__.py b/accountant/__init__.py index c4500f2..757bd73 100644 --- a/accountant/__init__.py +++ b/accountant/__init__.py @@ -46,30 +46,6 @@ app.debug = config.debug db = SQLAlchemy(app) -@contextmanager -def session_scope(): - from accountant import db - session = db.session - - try: - yield session - session.commit() - except: - session.rollback() - raise - finally: - session.close() - - -def session_aware(f): - def wrapper(*args, **kwargs): - with session_scope() as session: - kwargs['session'] = session - - return f(*args, **kwargs) - return wrapper - - # Must be after db declaration because the blueprints may need it. from .api import api from .frontend import frontend, frontend_js, frontend_css diff --git a/accountant/api/models/accounts.py b/accountant/api/models/accounts.py index ee01ff0..a0e1b4b 100644 --- a/accountant/api/models/accounts.py +++ b/accountant/api/models/accounts.py @@ -32,8 +32,8 @@ class Account(db.Model): self.authorized_overdraft = authorized_overdraft @classmethod - def query(cls, session, begin=None, end=None): - status_query = session.query( + def query(cls, begin=None, end=None): + status_query = db.session.query( Operation.account_id, func.sum(Operation.value).label("future"), func.sum( @@ -55,7 +55,7 @@ class Account(db.Model): ).subquery() if begin and end: - balance_query = session.query( + balance_query = db.session.query( Operation.account_id, func.sum( case( @@ -77,7 +77,7 @@ class Account(db.Model): Operation.account_id ).subquery() - query = session.query( + query = db.session.query( cls.id, cls.name, cls.authorized_overdraft, @@ -93,7 +93,7 @@ class Account(db.Model): balance_query, balance_query.c.account_id == cls.id ) else: - query = session.query( + query = db.session.query( cls.id, cls.name, cls.authorized_overdraft, diff --git a/accountant/api/models/operations.py b/accountant/api/models/operations.py index fa53a72..26b1312 100644 --- a/accountant/api/models/operations.py +++ b/accountant/api/models/operations.py @@ -102,15 +102,15 @@ class Operation(db.Model): self.canceled = canceled @classmethod - def query(cls, session, begin=None, end=None): + def query(cls, begin=None, end=None): # We have to use a join because the sold is not computed from the # begining. - base_query = session.query( + base_query = db.session.query( cls.id, cls.sold ).subquery() - query = session.query( + query = db.session.query( cls.id, cls.operation_date, cls.label, @@ -139,13 +139,13 @@ class Operation(db.Model): return query @classmethod - def get_categories_for_range(cls, session, account, begin, end): + def get_categories_for_range(cls, account, begin, end): if isinstance(account, int) or isinstance(account, str): account_id = account else: account_id = account.id - query = session.query( + query = db.session.query( cls.category, func.sum( case([(func.sign(cls.value) == -1, cls.value)], else_=0) @@ -168,7 +168,7 @@ class Operation(db.Model): return query @classmethod - def get_ohlc_per_day_for_range(cls, session, account, begin=None, end=None): + def get_ohlc_per_day_for_range(cls, account, begin=None, end=None): if isinstance(account, int) or isinstance(account, str): account_id = account else: @@ -182,7 +182,7 @@ class Operation(db.Model): previous = sold - cls.value - subquery = session.query( + subquery = db.session.query( cls.operation_date, sold.label("sold"), previous.label("previous") @@ -191,7 +191,7 @@ class Operation(db.Model): cls.canceled == false() ).subquery() - query = session.query( + query = db.session.query( subquery.c.operation_date, func.first_value(subquery.c.previous).over( partition_by=subquery.c.operation_date diff --git a/accountant/api/models/scheduled_operations.py b/accountant/api/models/scheduled_operations.py index a1db35e..7becf88 100644 --- a/accountant/api/models/scheduled_operations.py +++ b/accountant/api/models/scheduled_operations.py @@ -62,8 +62,8 @@ class ScheduledOperation(db.Model): self.category = category @classmethod - def query(cls, session): - return session.query( + def query(cls): + return db.session.query( cls ).order_by( desc(cls.day), @@ -71,9 +71,9 @@ class ScheduledOperation(db.Model): cls.label, ) - def reschedule(self, session): + def reschedule(self): # 1) delete unconfirmed operations for this account. - session.query( + db.session.query( Operation ).filter( Operation.scheduled_operation_id == self.id, @@ -106,7 +106,7 @@ class ScheduledOperation(db.Model): # Search if a close operation exists. - if not session.query( + if not db.session.query( Operation ).filter( Operation.account_id == self.account_id, @@ -127,4 +127,4 @@ class ScheduledOperation(db.Model): canceled=False ) - session.add(operation) + db.session.add(operation) diff --git a/accountant/api/models/users.py b/accountant/api/models/users.py index 5235647..2cafe42 100644 --- a/accountant/api/models/users.py +++ b/accountant/api/models/users.py @@ -55,7 +55,7 @@ class User(UserMixin, db.Model): return serializer.dumps({'id': self.id}) @classmethod - def verify_auth_token(cls, session, token): + def verify_auth_token(cls, token): serializer = Serializer(app.config['SECRET_KEY']) try: @@ -65,5 +65,5 @@ class User(UserMixin, db.Model): except BadSignature: return None - user = cls.query(session).get(data['id']) + user = cls.query().get(data['id']) return user diff --git a/accountant/api/views/accounts.py b/accountant/api/views/accounts.py index f37dce6..6ad945b 100644 --- a/accountant/api/views/accounts.py +++ b/accountant/api/views/accounts.py @@ -20,7 +20,7 @@ from flask.ext.restful import Resource, fields, reqparse, marshal_with_field from sqlalchemy.orm.exc import NoResultFound -from accountant import session_aware +from accountant import db from .. import api_api @@ -54,17 +54,15 @@ date_parser.add_argument('end', class AccountListResource(Resource): - @session_aware @marshal_with_field(fields.List(Object(resource_fields))) - def get(self, session): + def get(self): """ Returns accounts with their balances. """ - return Account.query(session).all(), 200 + return Account.query().all(), 200 - @session_aware @marshal_with_field(Object(resource_fields)) - def post(self, session): + def post(self): """ Create a new account. """ @@ -72,15 +70,13 @@ class AccountListResource(Resource): account = Account(**kwargs) - session.add(account) + db.session.add(account) # Flush session to have id in account. - session.flush() + db.session.flush() # Return account with solds. - return Account.query( - session - ).filter( + return Account.query().filter( Account.id == account.id ).one(), 201 @@ -92,9 +88,8 @@ class AccountListResource(Resource): class AccountResource(Resource): - @session_aware @marshal_with_field(Object(resource_fields)) - def get(self, account_id, session): + def get(self, account_id): """ Get account. """ @@ -102,36 +97,34 @@ class AccountResource(Resource): try: return Account.query( - session, **kwargs + **kwargs ).filter( Account.id == account_id ).one() except NoResultFound: return None, 404 - @session_aware @marshal_with_field(Object(resource_fields)) - def delete(self, account_id, session): + def delete(self, account_id): # Need to get the object to update it. - account = session.query(Account).get(account_id) + account = db.session.query(Account).get(account_id) if not account: return None, 404 - session.delete(account) + db.session.delete(account) return None, 204 - @session_aware @marshal_with_field(Object(resource_fields)) - def post(self, account_id, session): + def post(self, account_id): kwargs = parser.parse_args() assert (id not in kwargs or kwargs.id is None or kwargs.id == account_id) # Need to get the object to update it. - account = session.query(Account).get(account_id) + account = db.session.query(Account).get(account_id) if not account: return None, 404 @@ -140,12 +133,10 @@ class AccountResource(Resource): for k, v in kwargs.items(): setattr(account, k, v) - session.merge(account) + db.session.merge(account) # Return account with solds. - return Account.query( - session - ).filter( + return Account.query().filter( Account.id == account_id ).one() diff --git a/accountant/api/views/operations.py b/accountant/api/views/operations.py index b3a4803..63c269b 100644 --- a/accountant/api/views/operations.py +++ b/accountant/api/views/operations.py @@ -18,7 +18,7 @@ import dateutil.parser from flask.ext.restful import Resource, fields, reqparse, marshal_with_field -from accountant import session_aware +from accountant import db from .. import api_api @@ -62,54 +62,49 @@ range_parser.add_argument('end', type=lambda a: dateutil.parser.parse(a)) class OperationListResource(Resource): - @session_aware @marshal_with_field(fields.List(Object(resource_fields))) - def get(self, session): + def get(self): kwargs = range_parser.parse_args() return Operation.query( - session, begin=kwargs['begin'], end=kwargs['end'], ).filter( Operation.account_id == kwargs['account'] ).all() - @session_aware @marshal_with_field(Object(resource_fields)) - def post(self, session): + def post(self): kwargs = parser.parse_args() operation = Operation(**kwargs) - session.add(operation) + db.session.add(operation) return operation class OperationResource(Resource): - @session_aware @marshal_with_field(Object(resource_fields)) - def get(self, operation_id, session): + def get(self, operation_id): """ Get operation. """ - operation = session.query(Operation).get(operation_id) + operation = db.session.query(Operation).get(operation_id) if not operation: return None, 404 return operation - @session_aware @marshal_with_field(Object(resource_fields)) - def post(self, operation_id, session): + def post(self, operation_id): kwargs = parser.parse_args() assert (id not in kwargs or kwargs.id is None or kwargs.id == operation_id) - operation = session.query(Operation).get(operation_id) + operation = db.session.query(Operation).get(operation_id) if not operation: return None, 404 @@ -118,19 +113,18 @@ class OperationResource(Resource): for k, v in kwargs.items(): setattr(operation, k, v) - session.merge(operation) + db.session.merge(operation) return operation - @session_aware @marshal_with_field(Object(resource_fields)) - def delete(self, operation_id, session): - operation = session.query(Operation).get(operation_id) + def delete(self, operation_id): + operation = db.session.query(Operation).get(operation_id) if not operation: return None, 404 - session.delete(operation) + db.session.delete(operation) return operation @@ -143,12 +137,11 @@ category_resource_fields = { class CategoriesResource(Resource): - @session_aware @marshal_with_field(fields.List(Object(category_resource_fields))) - def get(self, session): + def get(self): kwargs = range_parser.parse_args() - return Operation.get_categories_for_range(session, **kwargs).all() + return Operation.get_categories_for_range(**kwargs).all() ohlc_resource_fields = { @@ -161,12 +154,11 @@ ohlc_resource_fields = { class SoldsResource(Resource): - @session_aware @marshal_with_field(fields.List(Object(ohlc_resource_fields))) - def get(self, session): + def get(self): kwargs = range_parser.parse_args() - return Operation.get_ohlc_per_day_for_range(session, **kwargs).all() + return Operation.get_ohlc_per_day_for_range(**kwargs).all() api_api.add_resource(OperationListResource, "/operations") diff --git a/accountant/api/views/scheduled_operations.py b/accountant/api/views/scheduled_operations.py index 50146f0..3b46d26 100644 --- a/accountant/api/views/scheduled_operations.py +++ b/accountant/api/views/scheduled_operations.py @@ -21,7 +21,7 @@ from flask.ext.restful import Resource, fields, reqparse, marshal_with_field from sqlalchemy import true from sqlalchemy.orm.exc import NoResultFound -from accountant import session_aware +from accountant import db from ..models.scheduled_operations import ScheduledOperation from ..models.operations import Operation @@ -60,23 +60,19 @@ get_parser.add_argument('account', type=int) class ScheduledOperationListResource(Resource): - @session_aware @marshal_with_field(fields.List(Object(resource_fields))) - def get(self, session): + def get(self): """ Get all scheduled operation for the account. """ kwargs = get_parser.parse_args() - return ScheduledOperation.query( - session - ).filter( + return ScheduledOperation.query().filter( ScheduledOperation.account_id == kwargs.account ).all() - @session_aware @marshal_with_field(Object(resource_fields)) - def post(self, session): + def post(): """ Add a new scheduled operation. """ @@ -84,41 +80,35 @@ class ScheduledOperationListResource(Resource): scheduled_operation = ScheduledOperation(**kwargs) - session.add(scheduled_operation) + db.session.add(scheduled_operation) - scheduled_operation.reschedule(session) + scheduled_operation.reschedule() - session.flush() + db.session.flush() return scheduled_operation, 201 class ScheduledOperationResource(Resource): - @session_aware @marshal_with_field(Object(resource_fields)) - def get(self, scheduled_operation_id, session): + def get(self, scheduled_operation_id): """ Get scheduled operation. """ try: - return ScheduledOperation.query( - session - ).filter( + return ScheduledOperation.query().filter( ScheduledOperation.id == scheduled_operation_id ).one() except NoResultFound: return None, 404 - @session_aware @marshal_with_field(Object(resource_fields)) - def delete(self, scheduled_operation_id, session): + def delete(self, scheduled_operation_id): """ Delete a scheduled operation. """ try: - scheduled_operation = ScheduledOperation.query( - session - ).filter( + scheduled_operation = ScheduledOperation.query().filter( ScheduledOperation.id == scheduled_operation_id ).one() except NoResultFound: @@ -137,13 +127,12 @@ class ScheduledOperationResource(Resource): # Delete unconfirmed operations operations = scheduled_operation.operations.delete() - session.delete(scheduled_operation) + db.session.delete(scheduled_operation) return None, 204 - @session_aware @marshal_with_field(Object(resource_fields)) - def post(self, scheduled_operation_id, session): + def post(self, scheduled_operation_id): """ Update a scheduled operation. """ @@ -153,9 +142,7 @@ class ScheduledOperationResource(Resource): or kwargs.id == scheduled_operation_id) try: - scheduled_operation = ScheduledOperation.query( - session - ).filter( + scheduled_operation = ScheduledOperation.query().filter( ScheduledOperation.id == scheduled_operation_id ).one() except NoResultFound: @@ -165,11 +152,11 @@ class ScheduledOperationResource(Resource): for k, v in kwargs.items(): setattr(scheduled_operation, k, v) - session.merge(scheduled_operation) + db.session.merge(scheduled_operation) - scheduled_operation.reschedule(session) + scheduled_operation.reschedule() - session.flush() + db.session.flush() return scheduled_operation