美文网首页
高效 Python 代码——类与继承

高效 Python 代码——类与继承

作者: rollingstarky | 来源:发表于2019-12-14 01:46 被阅读0次

    一、用辅助类(而不是字典)来维护程序的状态

    Python 内置的字典类型可以很好地保存某个对象在其生命周期中的(动态)内部状态。
    如下面的成绩单类:

    class SimpleGradebook(object):
        def __init__(self):
            self._grades = {}
    
        def add_student(self, name):
            self._grades[name] = []
    
        def report_grade(self, name, score):
            self._grades[name].append(score)
    
        def average_grade(self, name):
            grades = self._grades[name]
            return sum(grades) / len(grades)
    
    
    book = SimpleGradebook()
    book.add_student('Isaac Newton')
    book.report_grade('Isaac Newton', 90)
    book.report_grade('Isaac Newton', 80)
    
    print(book.average_grade('Isaac Newton'))
    # => 85.0
    

    在上面的 SimpleGradebook 类中,学生名字及其对应的成绩都保存在 _grades 字典结构中,这样就不必把每个学生都表示成对象并预设一个用于存放名字的属性了。

    字典类型用起来方便,但也容易因为过度使用导致一些问题。如果需要扩充上面成绩单类的功能,把学生成绩按照科目保存。则 _grades 字典中需要嵌入另一个字典存储科目与多次成绩的键值对。
    即类似这样的结构:{'Einstein': {'Math': [80, 90]}}

    class BySubjectGradebook(object):
        def __init__(self):
            self._grades = {}
    
        def add_student(self, name):
            self._grades[name] = {}
    
        def report_grade(self, name, subject, grade):
            by_subject = self._grades[name]
            grade_list = by_subject.setdefault(subject, [])
            grade_list.append(grade)
    
        def average_grade(self, name):
            by_subject = self._grades[name]
            total, count = 0, 0
            for grades in by_subject.values():
                total += sum(grades)
                count += len(grades)
            return total / count
    
    
    book = BySubjectGradebook()
    book.add_student('Albert Einstein')
    book.report_grade('Albert Einstein', 'Math', 80)
    book.report_grade('Albert Einstein', 'Math', 90)
    book.report_grade('Albert Einstein', 'Gym', 70)
    book.report_grade('Albert Einstein', 'Gym', 80)
    
    print(book.average_grade('Albert Einstein'))
    # => 80.0
    

    假设需求再次改变,在记录某个分数的同时,还需要记录该次成绩占该科目历次成绩的权重。。。此时用于保存成绩的数据结构可以改成这样:
    {'Einstein': {'Math': [(80, 0.4), (90, 0.6)]}}
    但是对于新的 average_grade 方法来说,处理上述数据记录的代码就比较难以理解了。

    把嵌套结构重构为类

    用来保存程序状态的数据结构一旦过于复杂(如包含多层嵌套),则应该将其拆解为类,提供更为明确的接口,同时更好的封装数据。

    from collections import namedtuple
    
    Grade = namedtuple('Grade', ('score', 'weight'))
    
    # 科目类。_grades 属性用于保存带权重的分数(Grade())对象
    # average_grade 方法用于按权重计算本科成绩的平均分
    class Subject(object):
        def __init__(self):
            self._grades = []
    
        def report_grade(self, score, weight):
            self._grades.append(Grade(score, weight))
    
        def average_grade(self):
            total, total_weight = 0, 0
            for grade in self._grades:
                total += grade.score * grade.weight
                total_weight += grade.weight
            return total / total_weight
    
    # 学生类。_subjects 属性用于保存该学生的所有科目(Subject())对象
    # average_grade 方法用于计算该学生所有科目成绩的平均分
    class Student(object):
        def __init__(self):
            self._subjects = {}
    
        def subject(self, name):
            if name not in self._subjects:
                self._subjects[name] = Subject()
            return self._subjects[name]
    
        def average_grade(self):
            total, count = 0, 0
            for subject in self._subjects.values():
                total += subject.average_grade()
                count += 1
            return total / count
    
    # 成绩单类。_students 属性保存所有的学生(Student())对象
    class Gradebook(object):
        def __init__(self):
            self._students = {}
    
        def student(self, name):
            if name not in self._students:
                self._students[name] = Student()
            return self._students[name]
    
    
    book = Gradebook()
    albert = book.student('Albert Einstein')
    math = albert.subject('Math')
    math.report_grade(80, 0.4)
    math.report_grade(90, 0.6)
    
    print(albert.average_grade())
    # => 86.0
    
    要点
    • 尽量不使用多层嵌套的字典(如包含其他字典的字典)存储程序状态,也不要使用过长的元组
    • 容器中包含简单又不可变的数据,可以先使用 namedtuple 表示,有需要时再改为完整的类
    • 用于保存程序内部状态的字典若过于复杂,应将其拆解为多个辅助类

    二、用 super 初始化父类

    初始化父类的传统方式,是在子类里直接调用父类的 __init__ 方法:

    class MyBaseClass(object):
        def __init__(self, value):
            self.value = value
    
    class MyChildClass(MyBaseClass):
        def __init__(self):
            MyBaseClass.__init__(self, 5)
    

    上述方法对于简单的继承行为是可行的,但是很多情况下仍会出现问题。
    首先若子类受到多重继承的影响,则直接调用父类的 __init__ 方法会产生无法预知的行为(调用顺序不确定)。

    class MyBaseClass(object):
        def __init__(self, value):
            self.value = value
    
    class TimesTwo(object):
        def __init__(self):
            self.value *= 2
    
    class PlusFive(object):
        def __init__(self):
            self.value += 5
    
    class OneWay(MyBaseClass, TimesTwo, PlusFive):
        def __init__(self, value):
            MyBaseClass.__init__(self, value)
            TimesTwo.__init__(self)
            PlusFive.__init__(self)
    
    class AnotherWay(MyBaseClass, TimesTwo, PlusFive):
        def __init__(self, value):
            MyBaseClass.__init__(self, value)
            PlusFive.__init__(self)
            TimesTwo.__init__(self)
    
    foo = OneWay(5)
    print(f"First ordering is (5 * 2) + 5 = {foo.value}")
    
    bar = AnotherWay(5)
    print(f"Second ordering still is {foo.value}")
    

    OneWayAnotherWay 中定义了两种完全不同的调用父类 __init__ 方法的顺序,但实际执行的结果却是相同的,导致子类代码中定义的调用顺序与子类实际产生的行为不一致。

    此外在菱形继承中,直接调用父类的构造器也会出现问题。菱形继承是指子类继承自两个单独的超类,而这两个父类又都继承自同一个公共基类。这种继承方式会使菱形顶部的公共基类多次执行其 __init__ 方法,产生意想不到的行为。

    class MyBaseClass(object):
        def __init__(self, value):
            self.value = value
    
    class TimesFive(MyBaseClass):
        def __init__(self, value):
            MyBaseClass.__init__(self, value)
            self.value *= 5
    
    class PlusTwo(MyBaseClass):
        def __init__(self, value):
            MyBaseClass.__init__(self, value)
            self.value += 2
    
    class ThisWay(TimesFive, PlusTwo):
        def __init__(self, value):
            TimesFive.__init__(self, value)
            PlusTwo.__init__(self, value)
    
    foo = ThisWay(5)
    print(f"Should be (5 * 5) + 2 = 27 but actully is {foo.value}")
    # => Should be (5 * 5) + 2 = 27 but actully is 7
    

    最终结果为 7,原因是在调用第二个父类的构造器(即 PlusTwo.__init__)时,公共基类的构造器(即 MyBaseClass.__init__)会再次被调用导致 value 重新变成 5,而不能保持 TimesFive.__init__ 之后的 25。

    通过 super 调用父类的构造器:
    class MyBaseClass(object):
        def __init__(self, value):
            self.value = value
    
    class TimesFive(MyBaseClass):
        def __init__(self, value):
            super(__class__, self).__init__(value)
            self.value *= 5
    
    class PlusTwo(MyBaseClass):
        def __init__(self, value):
            super(__class__, self).__init__(value)
            self.value += 2
    
    class ThisWay(TimesFive, PlusTwo):
        def __init__(self, value):
            super(__class__, self).__init__(value)
    
    foo = ThisWay(5)
    print(f"Should be (5 * 5) + 2 = 27 but actully is {foo.value}")
    # => Should be (5 * 5) + 2 = 27 but actully is 35
    

    此时程序的行为可以说和预期相符合了,注意后三个类中 super 的使用。其中传入 super 的第一个参数 __class__ 用来指代当前类本身。

    三、只在使用 Mix-in 制作工具类时进行多重继承

    应在 Python 编程中尽量避开多重继承。
    若一定要利用多重继承带来的便捷及封装性,应考虑编写 mix-in 类。
    mix-in 是一种“小型”类,其中只定义了其他类可能需要提供的一套附加方法,但是不定义自身的实例属性,也不要求继承者调用自己的 __init__ 构造器。

    可以在 mix-in 里面通过动态检测机制编写一套通用的功能代码,根据对各类对象当前状态的判定,确定代码实际的行为。从而将 mix-in 应用到多个不同的类上面。
    分层地组合 mix-in 类可以减少重复代码并提升代码复用程度。
    如需要将内存中的 Python 对象转换为字典结构(即通常所说的序列化操作),可以创建下面的 mix-in 类实现此功能并添加 public 方法供其他类继承。重点在于通过 hasattr 动态地访问实例的属性、用 isinstance 动态地检测对象类型、用 __dict__ 访问实例内部的字典。

    # todict.py
    class ToDictMixin(object):
        def to_dict(self):
            return self._traverse_dict(self.__dict__)
    
        def _traverse_dict(self, instance_dict):
            output = {}
            for key, value in instance_dict.items():
                output[key] = self._traverse(key, value)
            return output
    
        def _traverse(self, key, value):
            if isinstance(value, ToDictMixin):
                return value.to_dict()
            elif isinstance(value, dict):
                return self._traverse_dict(value)
            elif isinstance(value, list):
                return [self._traverse(key, i) for i in value]
            elif hasattr(value, '__dict__'):
                return self._traverse_dict(value.__dict__)
            else:
                return value
    
    # 利用 mix-in 将二叉树表示为字典
    class BinaryTree(ToDictMixin):
        def __init__(self, value, left=None, right=None):
            self.value = value
            self.left = left
            self.right = right
    
    
    if __name__ == '__main__':
        tree = BinaryTree(10,
                          left=BinaryTree(7, right=BinaryTree(9)),
                          right=BinaryTree(13, left=BinaryTree(11)))
        print(tree.to_dict())
    # => {'value': 10, 'left': {'value': 7, 'left': None, 'right': {'value': 9, 'left': None, 'right': None}}, 'right': {'value': 13, 'left': {'value': 11, 'left': None, 'right': None}, 'right': None}}
    

    mix-in 最大的优势在于,可以随时向基类中添加额外的通用功能,并且在必要时覆盖重写某些方法。
    多个 mix-in 之间也可以相互组合。
    在前面 todict.py 代码的基础上,可以再定义一个 JsonMixin 用来为任意类提供通用的 JSON 序列化功能。JsonMixin 的定义代码决定了继承自它的类需要包含 to_dict 方法(比如可以从 ToDictMixin 中继承),且其 __init__ 方法接受关键字参数:

    import json
    from todict import ToDictMixin
    
    class JsonMixin(object):
        @classmethod
        def from_json(cls, data):
            kwargs = json.loads(data)
            return cls(**kwargs)
    
        def to_json(self):
            return json.dumps(self.to_dict())
    
    class DatacenterRack(ToDictMixin, JsonMixin):
        def __init__(self, switch=None, machines=None):
            self.switch = Switch(**switch)
            self.machines = [
                Machine(**kwargs) for kwargs in machines
            ]
    
    class Switch(ToDictMixin, JsonMixin):
        def __init__(self, ports=None, speed=None):
            self.ports = ports
            self.speed = speed
    
    class Machine(ToDictMixin, JsonMixin):
        def __init__(self, cores=None, ram=None, disk=None):
            self.cores = cores
            self.ram = ram
            self.disk = disk
    
    serialized = """{
        "switch": {"ports": 5, "speed": 1e9},
        "machines": [
            {"cores": 8, "ram": 32e9, "disk": 5e12},
            {"cores": 4, "ram": 16e9, "disk": 1e12},
            {"cores": 2, "ram": 4e9, "disk": 500e9}
        ]
    }"""
    
    deserialized = DatacenterRack.from_json(serialized)
    roundtrip = deserialized.to_json()
    assert json.loads(serialized) == json.loads(roundtrip)
    

    ToDictMixinToJsonMixin 两个 mix-in 中分别定义了不同的通用功能组件,符合规范的多重继承下的子类则可以直接使用这两者提供的 to_dictto_json 方法,达到功能整合的效果。

    要点
    • 能用 mix-in 组件实现的效果,就不用多重继承来做
    • 将各功能实现为可插拔的 mix-in 组件,让子类选择继承需要的组件
    • 简单行为封装到 mix-in 组件里,多个 mix-in 组合成复杂行为

    四、多使用 public 属性

    Python 类的属性有 publicprivate 两种,任何人都可以在对象上通过 dot 操作符(.)访问 public 属性。
    private 属性是名称中以两个下划线开头的属性,可以被当前类中的方法访问。但从类外部直接访问 private 属性会报 AttributeError 异常。

    class MyObject(object):
        def __init__(self):
            self.public_field = 5
            self.__private_field = 10
    
        def get_private_field(self):
            return self.__private_field
    
    
    foo = MyObject()
    print(foo.public_field)
    # => 5
    print(foo.get_private_field())
    # => 10
    print(foo.__private_field)
    # => AttributeError: 'MyObject' object has no attribute '__private_field'
    

    类方法可以访问当前类的私有属性。子类无法访问父类的私有字段

    class MyObject(object):
        def __init__(self):
            self.__private_field = 71
    
        @classmethod
        def get_private_field(cls, instance):
            return instance.__private_field
    
    
    class MyChildObject(MyObject):
        def get_private_field(self):
            return self.__private_field
    
    
    bar = MyObject()
    print(MyObject.get_private_field(bar))
    # => 71
    bar_child = MyChildObject()
    print(bar_child.get_private_field())
    # => AttributeError: 'MyChildObject' object has no attribute '_MyChildObject__private_field'
    

    Python 会对私有属性的名称做一些简单的变换,这种变换导致了私有属性对类外部不可见,同时子类也无法访问父类的私有属性。也就是说,Python 并没有从语法上严格保证 private 字段的私密性。
    在子类的继承体系发生变化时,对 private 字段的引用很容易失效,从而导致子类出现错误。

    为了尽量减少无意义的访问内部属性导致的意外,Python 习惯用单下划线开头的字段表示受保护(protected)字段,当前类之外的代码使用这些字段时应格外注意。
    一般来说,宁可让子类更多地访问父类的 protected 属性,也尽量不要把这些属性设置成 private。并且应该在文档中说明每个 protected 字段的含义,在扩展代码时如何保证数据安全。

    class MyClass(object):
        def __init__(self, value):
            # This stores the user-supplied value for the object.
            # It should be coercible to a string. Once assigned for
            # the object it should be treated as immutable.
            self._value = value
    
    要点
    • Python 编译器无法严格保证 private 字段的私密性
    • 应多使用 protected 属性,并将合理用法在文档中说明
    • 只有子类不受自己控制时,为避免命名冲突才考虑使用 private 属性

    五、继承 collections.abc 以实现自定义的容器类型

    大部分的 Python 编程都是在定义类,类可以包含数据且能够描述数据与对象之间的交互方式。Python 中的类从某种程度上说都是封装了属性与功能容器。

    如果要设计功能比较简单的序列,可以直接继承 Python 内置的 list 类型。如创建一种可以统计各元素出现频率的自定义列表:

    class FrequencyList(list):
        def __init__(self, members):
            super().__init__(members)
    
        def frequency(self):
            counts = {}
            for item in self:
                counts.setdefault(item, 0)
                counts[item] += 1
            return counts
    
    foo = FrequencyList(['a', 'b', 'a', 'c', 'b', 'a', 'd'])
    print(f'Length is {len(foo)}')
    print(f'Frequency: ', foo.frequency())
    # => Length is 7
    # => Frequency:  {'a': 3, 'b': 2, 'c': 1, 'd': 1}
    

    假设某对象本身并不是 list 类型的子类,但是需要它表现得像 list 一样,可以通过下标访问其元素。
    Python 用一些名称比较特殊的实例方法来实现与容器有关的行为。如需要用下标访问序列中的元素,可以考虑实现序列对象的 __getitem__ 方法:

    >>> bar = [1, 2, 3]
    >>> bar[0]
    1
    >>> bar.__getitem__(0)
    1
    

    如下面的二叉树类实现了 __getitem__ 方法,使得不仅可以按深度优先的次序遍历(_traverse())二叉树中的对象,还可以通过下标访问:

    # binarynode.py
    class BinaryNode(object):
        def __init__(self, value, left=None, right=None):
            self.value = value
            self.left = left
            self.right = right
    
    
    class IndexableNode(BinaryNode):
        def _traverse(self):
            if self.left is not None:
                yield from self.left._traverse()
            yield self
            if self.right is not None:
                yield from self.right._traverse()
    
        def __getitem__(self, index):
            for i, item in enumerate(self._traverse()):
                if i == index:
                    return item.value
            raise IndexError(f'Index {index} is out of range')
    
    
    class SequenceNode(IndexableNode):
        def __len__(self):
            for count, _ in enumerate(self._traverse(), 1):
                pass
            return count
    
    
    if __name__ == '__main__':
        tree = SequenceNode(
            10,
            left=SequenceNode(
                5,
                left=SequenceNode(2),
                right=SequenceNode(
                    6,
                    right=SequenceNode(7))),
            right=SequenceNode(
                15,
                left=SequenceNode(11)))
    
        print('LRR is', tree.left.right.right.value)
        # => LRR is 7
        print('Index 0 is', tree[0])
        # => Index 0 is 2
        print('11 in the tree?', 11 in tree)
        # => 11 in the tree? True
        print('Tree is', list(tree))
        # => Tree is [2, 5, 6, 7, 10, 11, 15]
        print('Tree length is', len(tree))
        # => Tree length is 7
    

    为了使序列可以通过内置的 len() 函数获取长度,SequenceNode 类中还实现了 __len__ 方法。而更多的功能就意味着需要额外实现更多的特殊方法。
    为了避免这些麻烦,可以使用内置的 collections.abc 模块。此模块定义了一系列抽象基类,提供了每一种容器类型所应具备的常用方法。继承这样的基类,如果忘记实现某个方法,collections.abc 模块会报错;如果实现了抽象基类要求的每一个方法,则基类会自动实现剩下的所有方法。

    from binarynode import SequenceNode
    from collections.abc import Sequence
    
    class BetterNode(SequenceNode, Sequence):
        pass
    
    tree = BetterNode(
        10,
        left=BetterNode(
            5,
            left=BetterNode(2),
            right=BetterNode(
                6,
                right=BetterNode(7))),
        right=BetterNode(
            15,
            left=BetterNode(11))
    )
    
    print('Index of 7 is', tree.index(7))
    # => Index of 7 is 3
    print('Count of 10 is', tree.count(10))
    # => Count of 10 is 1
    

    SequenceNode 中已经实现了 Sequence 要求的 __getitem____len__ 方法,因此 Sequence 基类为继承的 BetterNode 子类自动实现了 index()count() 方法。

    要点
    • 如自定义的容器子类比较简单,可直接继承 Python 内置的容器类型(如 list、dict 等)
    • 正确实现自定义容器类型可能需要编写大量的特殊方法
    • 编写自定义容器类型时,可以从 collections.abc 模块中的抽象基类继承,这些基类能保证子类具备统一的接口及行为

    参考资料

    Effective Python

    相关文章

      网友评论

          本文标题:高效 Python 代码——类与继承

          本文链接:https://www.haomeiwen.com/subject/uvpogctx.html