Pythonで不規則な2次元標準リストをflattenする

やりたいこと

[0, 0, [0, 0], 0, [0]]

みたいな不規則にlistが含まれる標準リストを以下のように平坦化したい

[0, 0, 0, 0, 0, 0]

やり方

一度全ての値をリスト化して、その上でitertools.chain.from_iterableを適用すればできた

def flatten_sequences(sequences: List[list]) -> list:
    sequences = [i if type(i) == list else [i] for i in sequences]
    flattened = list(itertools.chain.from_iterable(sequences))
    return flattened

実行結果

>>> flatten_sequences([0, 0, [0, 0], 0, [0]])
[0, 0, 0, 0, 0, 0]