tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

gn_ast.py (11208B)


      1 # Lint as: python3
      2 # Copyright 2023 The Chromium Authors
      3 # Use of this source code is governed by a BSD-style license that can be
      4 # found in the LICENSE file.
      5 """Helper script to use GN's JSON interface to make changes.
      6 
      7 AST implementation details:
      8  https://gn.googlesource.com/gn/+/refs/heads/main/src/gn/parse_tree.cc
      9 
     10 To dump an AST:
     11  gn format --dump-tree=json BUILD.gn > foo.json
     12 """
     13 
     14 from __future__ import annotations
     15 
     16 import dataclasses
     17 import functools
     18 import json
     19 import subprocess
     20 from typing import Callable, Dict, List, Optional, Tuple, TypeVar
     21 
     22 NODE_CHILD = 'child'
     23 NODE_TYPE = 'type'
     24 NODE_VALUE = 'value'
     25 
     26 _T = TypeVar('_T')
     27 
     28 
     29 def _create_location_node(begin_line=1):
     30    return {
     31        'begin_column': 1,
     32        'begin_line': begin_line,
     33        'end_column': 2,
     34        'end_line': begin_line,
     35    }
     36 
     37 
     38 def _wrap(node: dict):
     39    kind = node[NODE_TYPE]
     40    if kind == 'LIST':
     41        return StringList(node)
     42    if kind == 'BLOCK':
     43        return BlockWrapper(node)
     44    return NodeWrapper(node)
     45 
     46 
     47 def _unwrap(thing):
     48    if isinstance(thing, NodeWrapper):
     49        return thing.node
     50    return thing
     51 
     52 
     53 def _find_node(root_node: dict, target_node: dict):
     54    def recurse(node: dict) -> Optional[Tuple[dict, int]]:
     55        children = node.get(NODE_CHILD)
     56        if children:
     57            for i, child in enumerate(children):
     58                if child is target_node:
     59                    return node, i
     60                ret = recurse(child)
     61                if ret is not None:
     62                    return ret
     63        return None
     64 
     65    ret = recurse(root_node)
     66    if ret is None:
     67        raise Exception(
     68            f'Node not found: {target_node}\nLooked in: {root_node}')
     69    return ret
     70 
     71 @dataclasses.dataclass
     72 class NodeWrapper:
     73    """Base class for all wrappers."""
     74    node: dict
     75 
     76    @property
     77    def node_type(self) -> str:
     78        return self.node[NODE_TYPE]
     79 
     80    @property
     81    def node_value(self) -> str:
     82        return self.node[NODE_VALUE]
     83 
     84    @property
     85    def node_children(self) -> List[dict]:
     86        return self.node[NODE_CHILD]
     87 
     88    @functools.cached_property
     89    def first_child(self):
     90        return _wrap(self.node_children[0])
     91 
     92    @functools.cached_property
     93    def second_child(self):
     94        return _wrap(self.node_children[1])
     95 
     96    def is_list(self):
     97        return self.node_type == 'LIST'
     98 
     99    def is_identifier(self):
    100        return self.node_type == 'IDENTIFIER'
    101 
    102    def visit_nodes(self, callback: Callable[[dict],
    103                                             Optional[_T]]) -> List[_T]:
    104        ret = []
    105 
    106        def recurse(root: dict):
    107            value = callback(root)
    108            if value is not None:
    109                ret.append(value)
    110                return
    111            children = root.get(NODE_CHILD)
    112            if children:
    113                for child in children:
    114                    recurse(child)
    115 
    116        recurse(self.node)
    117        return ret
    118 
    119    def set_location_recursive(self, line):
    120        def helper(n: dict):
    121            loc = n.get('location')
    122            if loc:
    123                loc['begin_line'] = line
    124                loc['end_line'] = line
    125 
    126        self.visit_nodes(helper)
    127 
    128    def add_child(self, node, *, before=None):
    129        node = _unwrap(node)
    130        if before is None:
    131            self.node_children.append(node)
    132        else:
    133            before = _unwrap(before)
    134            parent_node, child_idx = _find_node(self.node, before)
    135            parent_node[NODE_CHILD].insert(child_idx, node)
    136 
    137            # Prevent blank lines between |before| and |node|.
    138            target_line = before['location']['begin_line']
    139            _wrap(node).set_location_recursive(target_line)
    140 
    141    def remove_child(self, node):
    142        node = _unwrap(node)
    143        parent_node, child_idx = _find_node(self.node, node)
    144        parent_node[NODE_CHILD].pop(child_idx)
    145 
    146 
    147 @dataclasses.dataclass
    148 class BlockWrapper(NodeWrapper):
    149    """Wraps a BLOCK node."""
    150    def __post_init__(self):
    151        assert self.node_type == 'BLOCK'
    152 
    153    def find_assignments(self, var_name=None):
    154        def match_fn(node: dict):
    155            assignment = AssignmentWrapper.from_node(node)
    156            if not assignment:
    157                return None
    158            if var_name is None or var_name == assignment.variable_name:
    159                return assignment
    160            return None
    161 
    162        return self.visit_nodes(match_fn)
    163 
    164 
    165 @dataclasses.dataclass
    166 class AssignmentWrapper(NodeWrapper):
    167    """Wraps a =, +=, or -= BINARY node where the LHS is an identifier."""
    168    def __post_init__(self):
    169        assert self.node_type == 'BINARY'
    170 
    171    @property
    172    def variable_name(self):
    173        return self.first_child.node_value
    174 
    175    @property
    176    def value(self):
    177        return self.second_child
    178 
    179    @property
    180    def list_value(self):
    181        ret = self.second_child
    182        assert isinstance(ret, StringList), 'Found: ' + ret.node_type
    183        return ret
    184 
    185    @property
    186    def operation(self):
    187        """The assignment operation. Either "=" or "+="."""
    188        return self.node_value
    189 
    190    @property
    191    def is_append(self):
    192        return self.operation == '+='
    193 
    194    def value_as_string_list(self):
    195        return StringList(self.value.node)
    196 
    197    @staticmethod
    198    def from_node(node: dict) -> Optional[AssignmentWrapper]:
    199        if node.get(NODE_TYPE) != 'BINARY':
    200            return None
    201        children = node[NODE_CHILD]
    202        assert len(children) == 2, (
    203            'Binary nodes should have two child nodes, but the node is: '
    204            f'{node}')
    205        left_child, right_child = children
    206        if left_child.get(NODE_TYPE) != 'IDENTIFIER':
    207            return None
    208        if node.get(NODE_VALUE) not in ('=', '+=', '-='):
    209            return None
    210        return AssignmentWrapper(node)
    211 
    212    @staticmethod
    213    def create(variable_name, value, operation='='):
    214        value_node = _unwrap(value)
    215        id_node = {
    216            'location': _create_location_node(),
    217            'type': 'IDENTIFIER',
    218            'value': variable_name,
    219        }
    220        return AssignmentWrapper({
    221            'location': _create_location_node(),
    222            'child': [id_node, value_node],
    223            'type': 'BINARY',
    224            'value': operation,
    225        })
    226 
    227    @staticmethod
    228    def create_list(variable_name, operation='='):
    229        return AssignmentWrapper.create(variable_name,
    230                                        StringList.create(),
    231                                        operation=operation)
    232 
    233 
    234 @dataclasses.dataclass
    235 class StringList(NodeWrapper):
    236    """Wraps a list node that contains only string literals."""
    237    def __post_init__(self):
    238        assert self.is_list()
    239 
    240        self.literals: List[str] = [
    241            x[NODE_VALUE].strip('"') for x in self.node_children
    242            if x[NODE_TYPE] == 'LITERAL'
    243        ]
    244 
    245    def add_literal(self, value: str):
    246        # For lists of deps, gn format will sort entries, but it will not
    247        # move entries past comment boundaries. Insert at the front by default
    248        # so that if sorting moves the value, and there is a comment boundary,
    249        # it will end up before the comment instead of immediately after the
    250        # comment (which likely does not apply to it).
    251        self.literals.insert(0, value)
    252        self.node_children.insert(
    253            0, {
    254                'location': _create_location_node(),
    255                'type': 'LITERAL',
    256                'value': f'"{value}"',
    257            })
    258 
    259    def remove_literal(self, value: str):
    260        self.literals.remove(value)
    261        quoted = f'"{value}"'
    262        children = self.node_children
    263        for i, node in enumerate(children):
    264            if node[NODE_VALUE] == quoted:
    265                children.pop(i)
    266                break
    267        else:
    268            raise ValueError(f'Did not find child with value {quoted}')
    269 
    270    @staticmethod
    271    def create() -> StringList:
    272        return StringList({
    273            'location': _create_location_node(),
    274            'begin_token': '[',
    275            'child': [],
    276            'end': {
    277                'location': _create_location_node(),
    278                'type': 'END',
    279                'value': ']'
    280            },
    281            'type': 'LIST',
    282        })
    283 
    284 
    285 class Target(NodeWrapper):
    286    """Wraps a target node.
    287 
    288    A target node is any function besides "template" with exactly two children:
    289      * Child 1: LIST with single string literal child
    290      * Child 2: BLOCK
    291 
    292    This does not actually find all targets. E.g. ignores those that use an
    293    expression for a name, or that use "target(type, name)".
    294    """
    295    def __init__(self, function_node: dict, name_node: dict):
    296        super().__init__(function_node)
    297        self.name_node = name_node
    298 
    299    @property
    300    def name(self) -> str:
    301        return self.name_node[NODE_VALUE].strip('"')
    302 
    303    # E.g. "android_library"
    304    @property
    305    def type(self) -> str:
    306        return self.node[NODE_VALUE]
    307 
    308    @property
    309    def block(self) -> BlockWrapper:
    310        block = self.second_child
    311        assert isinstance(block, BlockWrapper)
    312        return block
    313 
    314    def set_name(self, value):
    315        self.name_node[NODE_VALUE] = f'"{value}"'
    316 
    317    @staticmethod
    318    def from_node(node: dict) -> Optional[Target]:
    319        """Returns a Target if |node| is a target, None otherwise."""
    320        if node.get(NODE_TYPE) != 'FUNCTION':
    321            return None
    322        if node.get(NODE_VALUE) == 'template':
    323            return None
    324        children = node.get(NODE_CHILD)
    325        if not children or len(children) != 2:
    326            return None
    327        func_params_node, block_node = children
    328        if block_node.get(NODE_TYPE) != 'BLOCK':
    329            return None
    330        if func_params_node.get(NODE_TYPE) != 'LIST':
    331            return None
    332        param_nodes = func_params_node.get(NODE_CHILD)
    333        if param_nodes is None or len(param_nodes) != 1:
    334            return None
    335        name_node = param_nodes[0]
    336        if name_node.get(NODE_TYPE) != 'LITERAL':
    337            return None
    338        return Target(function_node=node, name_node=name_node)
    339 
    340 
    341 class BuildFile:
    342    """Represents the contents of a BUILD.gn file."""
    343    def __init__(self, path: str, root_node: dict):
    344        self.block = BlockWrapper(root_node)
    345        self.path = path
    346        self._original_content = json.dumps(root_node)
    347 
    348    def write_changes(self) -> bool:
    349        """Returns whether there were any changes."""
    350        new_content = json.dumps(self.block.node)
    351        if new_content == self._original_content:
    352            return False
    353        output = subprocess.check_output(
    354            ['gn', 'format', '--read-tree=json', self.path],
    355            text=True,
    356            input=new_content)
    357        if 'Wrote rebuilt from json to' not in output:
    358            raise Exception('JSON was invalid')
    359        return True
    360 
    361    @functools.cached_property
    362    def targets(self) -> List[Target]:
    363        return self.block.visit_nodes(Target.from_node)
    364 
    365    @functools.cached_property
    366    def targets_by_name(self) -> Dict[str, Target]:
    367        return {t.name: t for t in self.targets}
    368 
    369    @staticmethod
    370    def from_file(path):
    371        output = subprocess.check_output(
    372            ['gn', 'format', '--dump-tree=json', path], text=True)
    373        return BuildFile(path, json.loads(output))