Simplify session management.
This commit is contained in:
parent
5b43cd763e
commit
762efc64c8
@ -46,30 +46,6 @@ app.debug = config.debug
|
||||
db = SQLAlchemy(app)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_scope():
|
||||
from accountant import db
|
||||
session = db.session
|
||||
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def session_aware(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
with session_scope() as session:
|
||||
kwargs['session'] = session
|
||||
|
||||
return f(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
# Must be after db declaration because the blueprints may need it.
|
||||
from .api import api
|
||||
from .frontend import frontend, frontend_js, frontend_css
|
||||
|
@ -32,8 +32,8 @@ class Account(db.Model):
|
||||
self.authorized_overdraft = authorized_overdraft
|
||||
|
||||
@classmethod
|
||||
def query(cls, session, begin=None, end=None):
|
||||
status_query = session.query(
|
||||
def query(cls, begin=None, end=None):
|
||||
status_query = db.session.query(
|
||||
Operation.account_id,
|
||||
func.sum(Operation.value).label("future"),
|
||||
func.sum(
|
||||
@ -55,7 +55,7 @@ class Account(db.Model):
|
||||
).subquery()
|
||||
|
||||
if begin and end:
|
||||
balance_query = session.query(
|
||||
balance_query = db.session.query(
|
||||
Operation.account_id,
|
||||
func.sum(
|
||||
case(
|
||||
@ -77,7 +77,7 @@ class Account(db.Model):
|
||||
Operation.account_id
|
||||
).subquery()
|
||||
|
||||
query = session.query(
|
||||
query = db.session.query(
|
||||
cls.id,
|
||||
cls.name,
|
||||
cls.authorized_overdraft,
|
||||
@ -93,7 +93,7 @@ class Account(db.Model):
|
||||
balance_query, balance_query.c.account_id == cls.id
|
||||
)
|
||||
else:
|
||||
query = session.query(
|
||||
query = db.session.query(
|
||||
cls.id,
|
||||
cls.name,
|
||||
cls.authorized_overdraft,
|
||||
|
@ -102,15 +102,15 @@ class Operation(db.Model):
|
||||
self.canceled = canceled
|
||||
|
||||
@classmethod
|
||||
def query(cls, session, begin=None, end=None):
|
||||
def query(cls, begin=None, end=None):
|
||||
# We have to use a join because the sold is not computed from the
|
||||
# begining.
|
||||
base_query = session.query(
|
||||
base_query = db.session.query(
|
||||
cls.id,
|
||||
cls.sold
|
||||
).subquery()
|
||||
|
||||
query = session.query(
|
||||
query = db.session.query(
|
||||
cls.id,
|
||||
cls.operation_date,
|
||||
cls.label,
|
||||
@ -139,13 +139,13 @@ class Operation(db.Model):
|
||||
return query
|
||||
|
||||
@classmethod
|
||||
def get_categories_for_range(cls, session, account, begin, end):
|
||||
def get_categories_for_range(cls, account, begin, end):
|
||||
if isinstance(account, int) or isinstance(account, str):
|
||||
account_id = account
|
||||
else:
|
||||
account_id = account.id
|
||||
|
||||
query = session.query(
|
||||
query = db.session.query(
|
||||
cls.category,
|
||||
func.sum(
|
||||
case([(func.sign(cls.value) == -1, cls.value)], else_=0)
|
||||
@ -168,7 +168,7 @@ class Operation(db.Model):
|
||||
return query
|
||||
|
||||
@classmethod
|
||||
def get_ohlc_per_day_for_range(cls, session, account, begin=None, end=None):
|
||||
def get_ohlc_per_day_for_range(cls, account, begin=None, end=None):
|
||||
if isinstance(account, int) or isinstance(account, str):
|
||||
account_id = account
|
||||
else:
|
||||
@ -182,7 +182,7 @@ class Operation(db.Model):
|
||||
|
||||
previous = sold - cls.value
|
||||
|
||||
subquery = session.query(
|
||||
subquery = db.session.query(
|
||||
cls.operation_date,
|
||||
sold.label("sold"),
|
||||
previous.label("previous")
|
||||
@ -191,7 +191,7 @@ class Operation(db.Model):
|
||||
cls.canceled == false()
|
||||
).subquery()
|
||||
|
||||
query = session.query(
|
||||
query = db.session.query(
|
||||
subquery.c.operation_date,
|
||||
func.first_value(subquery.c.previous).over(
|
||||
partition_by=subquery.c.operation_date
|
||||
|
@ -62,8 +62,8 @@ class ScheduledOperation(db.Model):
|
||||
self.category = category
|
||||
|
||||
@classmethod
|
||||
def query(cls, session):
|
||||
return session.query(
|
||||
def query(cls):
|
||||
return db.session.query(
|
||||
cls
|
||||
).order_by(
|
||||
desc(cls.day),
|
||||
@ -71,9 +71,9 @@ class ScheduledOperation(db.Model):
|
||||
cls.label,
|
||||
)
|
||||
|
||||
def reschedule(self, session):
|
||||
def reschedule(self):
|
||||
# 1) delete unconfirmed operations for this account.
|
||||
session.query(
|
||||
db.session.query(
|
||||
Operation
|
||||
).filter(
|
||||
Operation.scheduled_operation_id == self.id,
|
||||
@ -106,7 +106,7 @@ class ScheduledOperation(db.Model):
|
||||
|
||||
# Search if a close operation exists.
|
||||
|
||||
if not session.query(
|
||||
if not db.session.query(
|
||||
Operation
|
||||
).filter(
|
||||
Operation.account_id == self.account_id,
|
||||
@ -127,4 +127,4 @@ class ScheduledOperation(db.Model):
|
||||
canceled=False
|
||||
)
|
||||
|
||||
session.add(operation)
|
||||
db.session.add(operation)
|
||||
|
@ -55,7 +55,7 @@ class User(UserMixin, db.Model):
|
||||
return serializer.dumps({'id': self.id})
|
||||
|
||||
@classmethod
|
||||
def verify_auth_token(cls, session, token):
|
||||
def verify_auth_token(cls, token):
|
||||
serializer = Serializer(app.config['SECRET_KEY'])
|
||||
|
||||
try:
|
||||
@ -65,5 +65,5 @@ class User(UserMixin, db.Model):
|
||||
except BadSignature:
|
||||
return None
|
||||
|
||||
user = cls.query(session).get(data['id'])
|
||||
user = cls.query().get(data['id'])
|
||||
return user
|
||||
|
@ -20,7 +20,7 @@ from flask.ext.restful import Resource, fields, reqparse, marshal_with_field
|
||||
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from accountant import session_aware
|
||||
from accountant import db
|
||||
|
||||
from .. import api_api
|
||||
|
||||
@ -54,17 +54,15 @@ date_parser.add_argument('end',
|
||||
|
||||
|
||||
class AccountListResource(Resource):
|
||||
@session_aware
|
||||
@marshal_with_field(fields.List(Object(resource_fields)))
|
||||
def get(self, session):
|
||||
def get(self):
|
||||
"""
|
||||
Returns accounts with their balances.
|
||||
"""
|
||||
return Account.query(session).all(), 200
|
||||
return Account.query().all(), 200
|
||||
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def post(self, session):
|
||||
def post(self):
|
||||
"""
|
||||
Create a new account.
|
||||
"""
|
||||
@ -72,15 +70,13 @@ class AccountListResource(Resource):
|
||||
|
||||
account = Account(**kwargs)
|
||||
|
||||
session.add(account)
|
||||
db.session.add(account)
|
||||
|
||||
# Flush session to have id in account.
|
||||
session.flush()
|
||||
db.session.flush()
|
||||
|
||||
# Return account with solds.
|
||||
return Account.query(
|
||||
session
|
||||
).filter(
|
||||
return Account.query().filter(
|
||||
Account.id == account.id
|
||||
).one(), 201
|
||||
|
||||
@ -92,9 +88,8 @@ class AccountListResource(Resource):
|
||||
|
||||
|
||||
class AccountResource(Resource):
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def get(self, account_id, session):
|
||||
def get(self, account_id):
|
||||
"""
|
||||
Get account.
|
||||
"""
|
||||
@ -102,36 +97,34 @@ class AccountResource(Resource):
|
||||
|
||||
try:
|
||||
return Account.query(
|
||||
session, **kwargs
|
||||
**kwargs
|
||||
).filter(
|
||||
Account.id == account_id
|
||||
).one()
|
||||
except NoResultFound:
|
||||
return None, 404
|
||||
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def delete(self, account_id, session):
|
||||
def delete(self, account_id):
|
||||
# Need to get the object to update it.
|
||||
account = session.query(Account).get(account_id)
|
||||
account = db.session.query(Account).get(account_id)
|
||||
|
||||
if not account:
|
||||
return None, 404
|
||||
|
||||
session.delete(account)
|
||||
db.session.delete(account)
|
||||
|
||||
return None, 204
|
||||
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def post(self, account_id, session):
|
||||
def post(self, account_id):
|
||||
kwargs = parser.parse_args()
|
||||
|
||||
assert (id not in kwargs or kwargs.id is None
|
||||
or kwargs.id == account_id)
|
||||
|
||||
# Need to get the object to update it.
|
||||
account = session.query(Account).get(account_id)
|
||||
account = db.session.query(Account).get(account_id)
|
||||
|
||||
if not account:
|
||||
return None, 404
|
||||
@ -140,12 +133,10 @@ class AccountResource(Resource):
|
||||
for k, v in kwargs.items():
|
||||
setattr(account, k, v)
|
||||
|
||||
session.merge(account)
|
||||
db.session.merge(account)
|
||||
|
||||
# Return account with solds.
|
||||
return Account.query(
|
||||
session
|
||||
).filter(
|
||||
return Account.query().filter(
|
||||
Account.id == account_id
|
||||
).one()
|
||||
|
||||
|
@ -18,7 +18,7 @@ import dateutil.parser
|
||||
|
||||
from flask.ext.restful import Resource, fields, reqparse, marshal_with_field
|
||||
|
||||
from accountant import session_aware
|
||||
from accountant import db
|
||||
|
||||
from .. import api_api
|
||||
|
||||
@ -62,54 +62,49 @@ range_parser.add_argument('end', type=lambda a: dateutil.parser.parse(a))
|
||||
|
||||
|
||||
class OperationListResource(Resource):
|
||||
@session_aware
|
||||
@marshal_with_field(fields.List(Object(resource_fields)))
|
||||
def get(self, session):
|
||||
def get(self):
|
||||
kwargs = range_parser.parse_args()
|
||||
|
||||
return Operation.query(
|
||||
session,
|
||||
begin=kwargs['begin'],
|
||||
end=kwargs['end'],
|
||||
).filter(
|
||||
Operation.account_id == kwargs['account']
|
||||
).all()
|
||||
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def post(self, session):
|
||||
def post(self):
|
||||
kwargs = parser.parse_args()
|
||||
|
||||
operation = Operation(**kwargs)
|
||||
|
||||
session.add(operation)
|
||||
db.session.add(operation)
|
||||
|
||||
return operation
|
||||
|
||||
|
||||
class OperationResource(Resource):
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def get(self, operation_id, session):
|
||||
def get(self, operation_id):
|
||||
"""
|
||||
Get operation.
|
||||
"""
|
||||
operation = session.query(Operation).get(operation_id)
|
||||
operation = db.session.query(Operation).get(operation_id)
|
||||
|
||||
if not operation:
|
||||
return None, 404
|
||||
|
||||
return operation
|
||||
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def post(self, operation_id, session):
|
||||
def post(self, operation_id):
|
||||
kwargs = parser.parse_args()
|
||||
|
||||
assert (id not in kwargs or kwargs.id is None
|
||||
or kwargs.id == operation_id)
|
||||
|
||||
operation = session.query(Operation).get(operation_id)
|
||||
operation = db.session.query(Operation).get(operation_id)
|
||||
|
||||
if not operation:
|
||||
return None, 404
|
||||
@ -118,19 +113,18 @@ class OperationResource(Resource):
|
||||
for k, v in kwargs.items():
|
||||
setattr(operation, k, v)
|
||||
|
||||
session.merge(operation)
|
||||
db.session.merge(operation)
|
||||
|
||||
return operation
|
||||
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def delete(self, operation_id, session):
|
||||
operation = session.query(Operation).get(operation_id)
|
||||
def delete(self, operation_id):
|
||||
operation = db.session.query(Operation).get(operation_id)
|
||||
|
||||
if not operation:
|
||||
return None, 404
|
||||
|
||||
session.delete(operation)
|
||||
db.session.delete(operation)
|
||||
|
||||
return operation
|
||||
|
||||
@ -143,12 +137,11 @@ category_resource_fields = {
|
||||
|
||||
|
||||
class CategoriesResource(Resource):
|
||||
@session_aware
|
||||
@marshal_with_field(fields.List(Object(category_resource_fields)))
|
||||
def get(self, session):
|
||||
def get(self):
|
||||
kwargs = range_parser.parse_args()
|
||||
|
||||
return Operation.get_categories_for_range(session, **kwargs).all()
|
||||
return Operation.get_categories_for_range(**kwargs).all()
|
||||
|
||||
|
||||
ohlc_resource_fields = {
|
||||
@ -161,12 +154,11 @@ ohlc_resource_fields = {
|
||||
|
||||
|
||||
class SoldsResource(Resource):
|
||||
@session_aware
|
||||
@marshal_with_field(fields.List(Object(ohlc_resource_fields)))
|
||||
def get(self, session):
|
||||
def get(self):
|
||||
kwargs = range_parser.parse_args()
|
||||
|
||||
return Operation.get_ohlc_per_day_for_range(session, **kwargs).all()
|
||||
return Operation.get_ohlc_per_day_for_range(**kwargs).all()
|
||||
|
||||
|
||||
api_api.add_resource(OperationListResource, "/operations")
|
||||
|
@ -21,7 +21,7 @@ from flask.ext.restful import Resource, fields, reqparse, marshal_with_field
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from accountant import session_aware
|
||||
from accountant import db
|
||||
|
||||
from ..models.scheduled_operations import ScheduledOperation
|
||||
from ..models.operations import Operation
|
||||
@ -60,23 +60,19 @@ get_parser.add_argument('account', type=int)
|
||||
|
||||
|
||||
class ScheduledOperationListResource(Resource):
|
||||
@session_aware
|
||||
@marshal_with_field(fields.List(Object(resource_fields)))
|
||||
def get(self, session):
|
||||
def get(self):
|
||||
"""
|
||||
Get all scheduled operation for the account.
|
||||
"""
|
||||
kwargs = get_parser.parse_args()
|
||||
|
||||
return ScheduledOperation.query(
|
||||
session
|
||||
).filter(
|
||||
return ScheduledOperation.query().filter(
|
||||
ScheduledOperation.account_id == kwargs.account
|
||||
).all()
|
||||
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def post(self, session):
|
||||
def post():
|
||||
"""
|
||||
Add a new scheduled operation.
|
||||
"""
|
||||
@ -84,41 +80,35 @@ class ScheduledOperationListResource(Resource):
|
||||
|
||||
scheduled_operation = ScheduledOperation(**kwargs)
|
||||
|
||||
session.add(scheduled_operation)
|
||||
db.session.add(scheduled_operation)
|
||||
|
||||
scheduled_operation.reschedule(session)
|
||||
scheduled_operation.reschedule()
|
||||
|
||||
session.flush()
|
||||
db.session.flush()
|
||||
|
||||
return scheduled_operation, 201
|
||||
|
||||
|
||||
class ScheduledOperationResource(Resource):
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def get(self, scheduled_operation_id, session):
|
||||
def get(self, scheduled_operation_id):
|
||||
"""
|
||||
Get scheduled operation.
|
||||
"""
|
||||
try:
|
||||
return ScheduledOperation.query(
|
||||
session
|
||||
).filter(
|
||||
return ScheduledOperation.query().filter(
|
||||
ScheduledOperation.id == scheduled_operation_id
|
||||
).one()
|
||||
except NoResultFound:
|
||||
return None, 404
|
||||
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def delete(self, scheduled_operation_id, session):
|
||||
def delete(self, scheduled_operation_id):
|
||||
"""
|
||||
Delete a scheduled operation.
|
||||
"""
|
||||
try:
|
||||
scheduled_operation = ScheduledOperation.query(
|
||||
session
|
||||
).filter(
|
||||
scheduled_operation = ScheduledOperation.query().filter(
|
||||
ScheduledOperation.id == scheduled_operation_id
|
||||
).one()
|
||||
except NoResultFound:
|
||||
@ -137,13 +127,12 @@ class ScheduledOperationResource(Resource):
|
||||
# Delete unconfirmed operations
|
||||
operations = scheduled_operation.operations.delete()
|
||||
|
||||
session.delete(scheduled_operation)
|
||||
db.session.delete(scheduled_operation)
|
||||
|
||||
return None, 204
|
||||
|
||||
@session_aware
|
||||
@marshal_with_field(Object(resource_fields))
|
||||
def post(self, scheduled_operation_id, session):
|
||||
def post(self, scheduled_operation_id):
|
||||
"""
|
||||
Update a scheduled operation.
|
||||
"""
|
||||
@ -153,9 +142,7 @@ class ScheduledOperationResource(Resource):
|
||||
or kwargs.id == scheduled_operation_id)
|
||||
|
||||
try:
|
||||
scheduled_operation = ScheduledOperation.query(
|
||||
session
|
||||
).filter(
|
||||
scheduled_operation = ScheduledOperation.query().filter(
|
||||
ScheduledOperation.id == scheduled_operation_id
|
||||
).one()
|
||||
except NoResultFound:
|
||||
@ -165,11 +152,11 @@ class ScheduledOperationResource(Resource):
|
||||
for k, v in kwargs.items():
|
||||
setattr(scheduled_operation, k, v)
|
||||
|
||||
session.merge(scheduled_operation)
|
||||
db.session.merge(scheduled_operation)
|
||||
|
||||
scheduled_operation.reschedule(session)
|
||||
scheduled_operation.reschedule()
|
||||
|
||||
session.flush()
|
||||
db.session.flush()
|
||||
|
||||
return scheduled_operation
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user