163 lines
4.2 KiB
Python
163 lines
4.2 KiB
Python
"""
This file is part of Accountant.

Accountant is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Accountant is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with Accountant. If not, see <http://www.gnu.org/licenses/>.
"""
|
|
# vim: set tw=80 ts=4 sw=4 sts=4:
|
|
from functools import wraps
|
|
|
|
import arrow
|
|
|
|
from flask import request, g, current_app as app
|
|
from flask_restplus import Namespace, Resource, fields, marshal_with_field
|
|
|
|
from ..fields import Object
|
|
|
|
from ..models.users import User
|
|
|
|
|
|
def load_user_from_token(token):
    """Return the user authenticated by ``token``.

    This is a thin wrapper around ``User.verify_auth_token``: its result is
    handed back unchanged (presumably a ``User`` for a valid token and a
    falsy value otherwise -- confirm against the model).
    """
    user = User.verify_auth_token(token)
    return user
|
|
|
|
|
|
def load_user_from_auth(auth):
    """Load a user from the raw value of an ``Authorization`` header.

    The header is expected to look like ``Bearer <token>``. The scheme is
    only recognised as a leading prefix and is matched case-insensitively
    (HTTP auth schemes are case-insensitive). Surrounding whitespace around
    the token is ignored. A value that does not use the ``Bearer`` scheme is
    passed through unchanged as the token, as before.

    auth -- the ``Authorization`` header value.
    Returns whatever ``load_user_from_token`` returns for the extracted token.
    """
    # Only strip a *leading* scheme; the previous str.replace() would also
    # remove a "Bearer " appearing anywhere inside the header value.
    scheme, _, credentials = auth.strip().partition(' ')
    if scheme.lower() == 'bearer':
        token = credentials.strip()
    else:
        token = auth
    return load_user_from_token(token)
|
|
|
|
|
|
def requires_auth(f):
    """Decorator to check user authentication before handling a request.

    The decorated view runs only when the ``Authorization`` request header
    resolves to a truthy user through ``load_user_from_auth``. That user is
    stored in ``g.user`` for the view to use. Otherwise the request is
    aborted with HTTP 401.
    """
    # The docstring previously sat between ``@wraps(f)`` and ``def wrapped``,
    # which is a SyntaxError; it now lives on the outer function.
    @wraps(f)
    def wrapped(*args, **data):
        """Check user authentication from request headers and abort with 401
        if unauthorized."""
        auth = request.headers.get('Authorization')
        # An absent or empty header cannot carry a token; skip the lookup.
        user = load_user_from_auth(auth) if auth else None

        if user:
            g.user = user
            return f(*args, **data)

        # ns.abort is expected to raise (flask abort), so the view never runs.
        ns.abort(
            401,
            error_message='Please login before executing this request.'
        )

    return wrapped
|
|
|
|
|
|
# Namespace grouping every user-related route and model below.
# pylint: disable=invalid-name
ns = Namespace(name='user', description='User management')
|
|
|
|
# Token with expiration time and type, as returned by a successful login.
token_model = ns.model('Token', dict(
    token=fields.String(
        required=True,
        readonly=True,
        description='Token value',
    ),
    expiration=fields.DateTime(
        dt_format='iso8601',
        required=True,
        readonly=True,
        description='Expiration time of the token',
    ),
    token_type=fields.String(
        required=True,
        readonly=True,
        description='Token type',
    ),
))
|
|
|
|
# User model.
user_model = ns.model('User', {
    'id': fields.Integer(
        default=None,
        required=True,
        readonly=True,
        description='Id of the user'),
    'email': fields.String(
        required=True,
        readonly=True,
        # Was misspelled ``decription``, which is not a recognised field
        # option, so the description never appeared in the API docs.
        description='Email address of the user'),
    'active': fields.Boolean(
        required=True,
        readonly=True,
        description='Active state of the user')
})
|
|
|
|
# Login model: the credentials expected in the body of a login request.
login_model = ns.model(
    'Login',
    {
        'email': fields.String(
            required=True,
            description='Email to use for login',
        ),
        'password': fields.String(
            required=True,
            description='Plain text password to use for login',
        ),
    },
)
|
|
|
|
|
|
@ns.route('/login')
class LoginResource(Resource):
    """Resource to handle login operations."""

    @ns.marshal_with(token_model)
    @ns.doc(
        responses={
            200: ('OK', token_model),
            401: 'Unauthorized'
        })
    @ns.expect(login_model)
    def post(self):
        """Login to retrieve authentication token."""
        # ns.expect() is used without validate=True, so the payload may be
        # missing, not an object, or lack keys. Treat all of those as bad
        # credentials (401) instead of crashing with a KeyError/TypeError (500).
        data = self.api.payload
        if not isinstance(data, dict):
            data = {}
        email = data.get('email')
        password = data.get('password')

        user = None
        if email is not None and password is not None:
            user = User.query().filter(
                User.email == email
            ).one_or_none()

        if not user or not user.verify_password(password):
            ns.abort(401, error_message="Bad user or password.")

        token = user.generate_auth_token()
        # arrow's plural ``replace(seconds=...)`` shifting is deprecated and
        # removed in newer releases; ``shift`` adds a relative offset.
        expiration_time = arrow.now().shift(
            seconds=app.config['SESSION_TTL']
        )

        return {
            'token': token,
            'expiration': expiration_time.datetime,
            'token_type': 'Bearer'
        }, 200

    @requires_auth
    @ns.doc(
        security='apikey',
        responses={
            200: ('OK', user_model),
            # requires_auth aborts with 401 when no valid token is sent.
            401: 'Unauthorized'
        })
    @marshal_with_field(Object(user_model))
    def get(self):
        """Get authenticated user information."""
        # g.user is set by requires_auth before this body runs.
        return g.user, 200
|