# sqlalchemy-mptt.py (2.36 KB)
# !/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2018-09-12 16:32
# @Author  : Jackadam
# @Email   :jackadam@sina.com
# @File    : mptt.py
# @Software: PyCharm
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy_mptt import mptt_sessionmaker
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_mptt.mixins import BaseNestedSets

# Declarative base that every mapped class in this script derives from.
Base = declarative_base()


class Tree(Base, BaseNestedSets):
    """A named node stored in the ``tree`` table.

    The parent link (``parent_id``) and the ``children`` collection used
    elsewhere in this script are not declared here; they are supplied by the
    ``BaseNestedSets`` mixin from sqlalchemy_mptt.  This class adds only the
    primary key and a short display name.
    """

    __tablename__ = "tree"

    id = Column(Integer, primary_key=True)
    # Display label, declared as VARCHAR(8).
    name = Column(String(8))

    def __repr__(self):
        return "<Node ({})>".format(self.id)


# SQLite database file ``mptt.db`` relative to the working directory;
# echo=False keeps SQL statement logging off.
engine = create_engine('sqlite:///mptt.db', echo=False)

# Session factory wrapped by mptt_sessionmaker, which sqlalchemy_mptt documents
# as required for its automatic tree maintenance after a session flush.
mptt_session = mptt_sessionmaker(sessionmaker(bind=engine))
# Backward-compatible alias for the original (misspelled) module-level name.
mptt_ession = mptt_session

# Scoped-session registry used by the helpers in this script.  Its factory is
# also wrapped with mptt_sessionmaker; previously it was a plain sessionmaker,
# so inserts made through db_session skipped the mptt session hooks.
db_session = scoped_session(mptt_sessionmaker(sessionmaker(autocommit=False,
                                                           autoflush=False,
                                                           bind=engine)))



def print_tree(group_name, tab=1):
    """Print the subtree rooted at the node named *group_name*.

    Each node is printed on its own line, prefixed with ``'- '`` repeated
    *tab* times for the root and one more time for every level below it.
    Nothing is printed when no node has that name.

    :param str group_name: name of the node to use as the root of the output
    :param int tab: number of ``'- '`` prefixes on the root's line
    :raises MultipleResultsFound: if several nodes are named *group_name*
        (raised by ``one_or_none``)
    """
    group = db_session.query(Tree).filter_by(name=group_name).one_or_none()
    if group is None:
        return
    _print_node(group, tab)


def _print_node(node, tab):
    """Print *node* with *tab* ``'- '`` prefixes, then its children recursively."""
    print('- ' * tab + node.name)
    for child in node.children:
        # Recurse on the related object itself rather than re-querying by
        # name: a name lookup costs one query per node and raises
        # MultipleResultsFound as soon as two nodes share a name.
        # The indent grows by one step per level; the previous ``tab * 2``
        # made it grow exponentially with depth.
        _print_node(child, tab + 1)

# Demo hierarchy as (parent name, child names) pairs, inserted in this order;
# each parent must already exist when its pair is processed.
_DEMO_CHILDREN = (
    ('中国', ('河南', '河北', '山东', '山西', '陕西')),
    ('河南', ('郑州', '洛阳', '开封', '新乡', '新郑')),
)


def _seed_demo_data():
    """Create the ``tree`` table if needed and insert the demo rows once.

    Inserts a '中国' root, five provinces under it and five cities under
    '河南'.  Nothing is inserted when a '中国' node already exists, so
    running the script repeatedly does not duplicate rows.
    """
    # create_all skips tables that already exist (checkfirst defaults to True).
    Base.metadata.create_all(bind=engine)
    if db_session.query(Tree.id).filter_by(name='中国').first() is not None:
        return
    db_session.add(Tree(name='中国'))
    db_session.commit()
    # Commit one level at a time so each parent row has an id before its
    # children reference it through parent_id.
    for parent_name, child_names in _DEMO_CHILDREN:
        parent_id = (db_session.query(Tree.id)
                     .filter_by(name=parent_name)
                     .first()[0])
        db_session.add_all(
            [Tree(name=child_name, parent_id=parent_id)
             for child_name in child_names]
        )
        db_session.commit()


if __name__ == '__main__':
    _seed_demo_data()
    print_tree('中国')
    print_tree('河南')