dictで保持している特徴量のtrain_test_split
やりたいこと
以下のように辞書型で保持している特徴量を
{'field1': array([0, 1, 2, 3, 4, 5]), 'field2': array([5, 4, 3, 2, 1, 0]), 'label': array([1, 0, 1, 0, 0, 0])}
以下のように辞書型を保持したまま分割したい
{'field1': array([4, 0, 3]), 'field2': array([1, 5, 2]), 'label': array([0, 1, 0])} {'field1': array([5, 1, 2]), 'field2': array([0, 4, 3]), 'label': array([0, 0, 1])}
やり方
クラス分布に考慮した抽出などのオプションを自分で実装するのは結構しんどいので、scikit-learnのmodel_selection.train_test_splitをラップした処理を作成する ランダムなindexの配列を作成し、train_test_splitでlabelの配列と共に分割し、各arrayからindex指定で抽出すればできた
class Splitter: def __init__(self, train_size, label_col: str): self.train_size = train_size self.label_col = label_col self.train_indices = None self.test_indices = None def set_split_indices(self, field_to_values): total_length = len(field_to_values[self.label_col]) split_indices = np.array(range(total_length)) labels = field_to_values[self.label_col] self.train_indices, self.test_indices, _, _ = train_test_split( split_indices, labels, train_size=self.train_size,stratify=labels) def split(self, field_to_values): train = {field: values[self.train_indices] for field, values in field_to_values.items()} test = {field: values[self.test_indices] for field, values in field_to_values.items()} return train, test
実行結果
>>> field_to_values = {"field1": np.array([0, 1, 2, 3, 4, 5]), "field2": np.array([5, 4, 3, 2, 1, 0]), "label": np.array([1, 0, 1, 0, 0, 0])} >>> splitter = Splitter(train_size=0.5, label_col="label") >>> splitter.set_split_indices(field_to_values) >>> splitter.split(field_to_values) ({'field1': array([4, 0, 3]), 'field2': array([1, 5, 2]), 'label': array([0, 1, 0])}, {'field1': array([5, 1, 2]), 'field2': array([0, 4, 3]), 'label': array([0, 0, 1])})
今回複数辞書を同じindexで分割したかったのでクラス化したが、関数として書くと以下になる
def train_test_split_dict(field_to_values, train_size, label_col: str) total_length = len(field_to_values[label_col]) split_indices = np.array(range(total_length)) labels = field_to_values[label_col] train_indices, test_indices, _, _ = train_test_split( split_indices, labels, train_size=train_size, stratify=labels) train = {field: values[train_indices] for field, values in field_to_values.items()} test = {field: values[test_indices] for field, values in field_to_values.items()} return train, test