# -*- coding: utf-8 -*-
import string
import flask_restful
from flask import Flask, abort, jsonify
from flask_jwt_extended import ( JWTManager )
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow
from hashids import Hashids
from webcreator.response import ResponseCode, response_result
from webcreator.log import logger
from .config import config

# Create the module-level Flask application; create_app() configures it further.
app = Flask(__name__)
# Load configuration first so the extensions below see it when they bind to the app.
app.config.from_object(config)
# SQLAlchemy (database) extension, bound directly to the app.
db = SQLAlchemy(app)
# Marshmallow (serialization) extension.
ma = Marshmallow(app)
# JWT authentication support; the *_loader callbacks below customise its error responses.
jwt = JWTManager(app)

# Hashids encoder: lowercase letters + digits, encoded strings at least 8 characters long.
# Presumably used to expose integer IDs as opaque strings — verify against callers.
# NOTE(review): the salt is hard-coded in source. Changing it changes every encoded id,
# so if it is moved to config, the same value must be kept.
hash_ids = Hashids(salt='hvwptlmj027d5quf', min_length=8, alphabet=string.ascii_lowercase + string.digits)

# Keep references to Flask's original exception handlers. create_app() assigns them
# back onto the app after the blueprints are registered, so Flask's native error
# handling stays in effect (presumably undoing a replacement made by flask_restful.Api).
handle_exception = app.handle_exception
handle_user_exception = app.handle_user_exception

# Expired token
@jwt.expired_token_loader
def expired_token_callback(jwt_header, jwt_payload):
    """Build the JSON error response for an expired JWT.

    The decoded payload is logged at INFO level, and the token header is
    echoed back in ``data``. ``jsonify`` is called without a status, so the
    response carries Flask's default status code; the error is signalled by
    ``code`` 4011 in the body.
    """
    logger.info(jwt_payload)
    body = dict(code=4011, msg='token expired', data=jwt_header)
    return jsonify(body)

# Invalid token
@jwt.invalid_token_loader
def invalid_token_callback(error):  # the argument must stay: the extension always passes it in
    """Build the JSON error response for a token that failed validation.

    :param error: error value passed by flask_jwt_extended (presumably a
                  message string); echoed back in ``data``.
    :return: ``jsonify`` response whose body has ``code`` 4012.
    """
    body = dict(code=4012, msg='invalid token', data=error)
    return jsonify(body)

# Authorization check failed
@jwt.unauthorized_loader
def unauthorized_callback(error):
    """Build the JSON error response registered via ``jwt.unauthorized_loader``.

    :param error: error value passed by flask_jwt_extended (presumably a
                  message string); echoed back in ``data``.
    :return: ``jsonify`` response whose body has ``code`` 4013.
    """
    body = dict(code=4013, msg='unauthorized', data=error)
    return jsonify(body)

def _custom_abort(http_status_code, **kwargs):
    """
    Replacement for ``flask_restful.abort`` that reshapes HTTP 400 responses.

    For status 400, the ``message`` keyword argument is wrapped with
    ``response_result(ResponseCode.PARAMETER_ERROR, data=...)``:

    * a non-empty dict ``{param: info, ...}`` is reduced to its first entry,
      formatted as ``"param:info!"``;
    * any other value (string, None, empty dict) is passed through unchanged.

    Every other status code falls back to a plain ``abort(http_status_code)``;
    extra keyword arguments are ignored in that case.

    :param http_status_code: HTTP status code requested by the caller.
    :raises HTTPException: always, via ``abort``; this function never returns.
    """
    if http_status_code == 400:
        message = kwargs.get('message')
        if isinstance(message, dict) and message:
            # Report only the first failing parameter. An empty dict is passed
            # through instead of raising IndexError (which would surface as a 500).
            param, info = next(iter(message.items()))
            data = '{}:{}!'.format(param, info)
        else:
            data = message
        abort(jsonify(response_result(ResponseCode.PARAMETER_ERROR, data=data)))
    abort(http_status_code)

def _access_control(response):
    """
    解决跨域请求
    """
    # response.headers['Access-Control-Allow-Origin'] = '*'
    # response.headers['Access-Control-Allow-Methods'] = 'GET,HEAD,PUT,PATCH,POST,DELETE'
    # response.headers['Access-Control-Allow-Headers'] = 'Content-Type'
    # response.headers['Access-Control-Max-Age'] = 86400

    response.headers['Access-Control-Allow-Origin'] = '*'
    response.headers['Access-Control-Allow-Methods'] = 'GET, HEAD, PUT, PATCH, POST, DELETE, OPTIONS'
    response.headers['Access-Control-Allow-Headers'] = 'Origin, No-Cache, X-Requested-With, If-Modified-Since, Pragma, Last-Modified, Cache-Control, Expires, Content-Type, X-E4M-With, Authorization'
    response.headers['Access-Control-Expose-Headers'] = 'Authorization'
    response.headers['Access-Control-Max-Age'] = 86400
    response.headers['Access-Control-Request-Headers'] = 'Origin, X-Requested-With, Content-Type, Accept, Authorization'

    return response

def create_app(config):
    """
    Configure and return the module-level Flask ``app``.

    :param config: configuration object applied with ``app.config.from_object``
                   on top of the configuration loaded at import time.
    :return: the same global ``app`` instance. It has CORS headers, the custom
             400 abort format, database tables, and the ``/api/v1`` blueprint set
             up, and Flask's original exception handlers restored.
    """
    # Apply the configuration.
    app.config.from_object(config)
    # Add CORS headers to every response.
    app.after_request(_access_control)
    # Use the custom response format for flask_restful abort(400).
    flask_restful.abort = _custom_abort
    # Database initialisation. ``db`` was already bound by ``SQLAlchemy(app)`` at
    # import time. Flask-SQLAlchemy 3.x raises RuntimeError if init_app() registers
    # the same app twice, so only initialise when it is not registered yet.
    if 'sqlalchemy' not in app.extensions:
        db.init_app(app)
    # Create tables. Flask-SQLAlchemy 3.x requires an application context here;
    # pushing one is harmless on 2.x.
    # NOTE(review): create_all() only creates tables for models already imported;
    # confirm the models do not depend on the `views` import below.
    with app.app_context():
        db.create_all()
    # Register blueprints. The import is local, presumably to avoid a circular
    # import with modules that import `app`/`db` from here.
    from views import api_v1
    app.register_blueprint(api_v1, url_prefix='/api/v1')
    # Restore Flask's native exception handlers saved at import time.
    app.handle_exception = handle_exception
    app.handle_user_exception = handle_user_exception
    return app