def dfs(i, pre):
    """Count labels in the subtree rooted at ``i`` and record per-node answers.

    Uses an explicit post-order stack instead of recursion, so a deep
    (e.g. path-shaped) tree cannot exceed Python's recursion limit.

    Reads the enclosing-scope names ``labels`` (indexed by node) and
    ``edge_map`` (node -> adjacent nodes), and writes into ``ans``: for every
    node ``u`` in the subtree, ``ans[u]`` is set to the count of nodes in
    ``u``'s subtree whose label equals ``labels[u]``.

    Args:
        i: Root of the subtree to process.
        pre: Neighbour of ``i`` treated as its parent and excluded from the
            subtree; ``None`` when ``i`` is the tree root.

    Returns:
        dict mapping each label to its number of occurrences in ``i``'s subtree.
    """
    # NOTE(review): assumes edge_map describes a tree. Only the direct parent
    # is skipped, so a cycle would never terminate.
    finished = {}  # node -> label counts of its completed subtree
    stack = [(i, pre, False)]
    while stack:
        node, parent, expanded = stack.pop()
        children = [nxt for nxt in edge_map[node] if nxt != parent]
        if not expanded:
            # Revisit this node only after every child subtree is finished.
            stack.append((node, parent, True))
            stack.extend((child, node, False) for child in children)
            continue
        data = {labels[node]: 1}
        for child in children:
            # A child's counts are needed only by its parent, so pop to free them.
            for k, v in finished.pop(child).items():
                data[k] = data.get(k, 0) + v
        ans[node] = data[labels[node]]
        finished[node] = data
    return finished[i]
dfs(0, None) return ans
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
classSolution: defcountSubTrees(self, n: int, edges: List[List[int]], labels: str) -> \ List[int]: edge_map = defaultdict(list) for e in edges: edge_map[e[0]].append(e[1]) edge_map[e[1]].append(e[0]) ans = [1] * n
def dfs(i, pre):
    """Count labels in the subtree rooted at ``i`` and record per-node answers.

    Uses an explicit post-order stack instead of recursion, so a deep
    (e.g. path-shaped) tree cannot exceed Python's recursion limit.

    Reads the enclosing-scope names ``labels`` (indexed by node) and
    ``edge_map`` (node -> adjacent nodes), and writes into ``ans``: for every
    node ``u`` in the subtree, ``ans[u]`` is set to the count of nodes in
    ``u``'s subtree whose label equals ``labels[u]``.

    Args:
        i: Root of the subtree to process.
        pre: Neighbour of ``i`` treated as its parent and excluded from the
            subtree; ``None`` when ``i`` is the tree root.

    Returns:
        Counter of label occurrences within ``i``'s subtree.
    """
    # NOTE(review): assumes edge_map describes a tree. Only the direct parent
    # is skipped, so a cycle would never terminate.
    done = {}  # node -> Counter for its completed subtree
    pending = [(i, pre, False)]
    while pending:
        node, parent, ready = pending.pop()
        kids = [nxt for nxt in edge_map[node] if nxt != parent]
        if not ready:
            # Revisit this node only after every child subtree is finished.
            pending.append((node, parent, True))
            pending.extend((kid, node, False) for kid in kids)
            continue
        data = Counter({labels[node]: 1})
        for kid in kids:
            # In-place multiset add; pop because only the parent needs it.
            data += done.pop(kid)
        ans[node] = data[labels[node]]
        done[node] = data
    return done[i]