datasets库使用

Huggingface的datasets库,包含一些常见数据集和metric

load_dataset

1
2
3
from datasets import load_dataset

data=load_dataset('dataset_name','sub_name',split='train')  # split也支持切片语法,如 'train[:100]'

得到Dataset类的实例

num_rows查看行数

1
2
dataset['train'].num_rows
#3668

下面包括features类和数据

1
2
3
4
5
dataset['train'].features
#{'idx': Value(dtype='int32', id=None),
# 'label': ClassLabel(num_classes=2, names=['not_equivalent', #'equivalent'], id=None),
# 'sentence1': Value(dtype='string', id=None),
# 'sentence2': Value(dtype='string', id=None)}
1
2
3
4
5
6
7
8
9
10
11
#查看数据类型
dataset['train'].features['idx']
#Value(dtype='int32', id=None)

dataset['train'].features['label']
#ClassLabel(num_classes=2, names=['not_equivalent', 'equivalent'], id=None)

dataset['train'].features['sentence1']
#Value(dtype='string', id=None)

dataset['train'].features['label'].names
  • set_format() 设置喂给模型的数据,类型
1
dataset.set_format(type='torch',columns=['idx','labels'])
Local and remote files

加载本地文件csv,text,jsonlines,json

1
2
3
4
5
from datasets import load_dataset

dataset=load_dataset('csv',data_files='my_file.csv')

data=load_dataset('json',data_files='my_file.json',field='data')

加载remote file,csv ,zipped.csv

1
2
3
base_url="https://huggingface.co/datasets/lhoestq/demo1/resolve/main/data/"

dataset=load_dataset('csv',data_files={'train':base_url+'train.csv'})

load_dataset_builder

在download之前,查看数据集的dataset card

1
2
3
4
5
6
7
8
from datasets import load_dataset_builder
dataset_builder=load_dataset_builder('imdb')
#查看cache地址
dataset_builder.cache_dir
#查看info.features
dataset_builder.info.features
#查看split
dataset_builder.info.splits

get_dataset_config_names

获取subset的名字

1
2
3
4
from datasets import get_dataset_config_names

config=get_dataset_config_names('glue')
config

Dataset

  • dataset.info
  • dataset.split / description/ citation/homepage
  • dataset.shape
  • dataset.num_columns
  • dataset.num_rows
  • dataset.column_names
  • dataset.features
load memory中数据
  • Dataset.from_dict()
  • Dataset.from_pandas()
处理数据
  • dataset.sort('label')
  • dataset.shuffle(seed=42)
  • dataset.select([0,10,20,…])
  • dataset.filter(lambda example:example['sentence1'].startswith('A'))
  • dataset.train_test_split(test_size=0.1) 默认shuffle,可以设置shuffle=False
  • dataset.shard(num_shards=4,index=0) 划分为4份数据,选择0份
数据columns处理
  • dataset.rename_column('sentence1','s1')

  • dataset.remove_columns('label')

    修改column数据类型

  • dataset.cast(new_features) new_features=data.features.copy()

  • dataset.cast_column('audio',Audio(sampling_rate=16000))

  • dataset.flatten() 展平嵌套结构数据

Map
1
data=dataset.map(lambda example:{'new':example['s1']},remove_columns=['s1'])
连接数据集

columns名相同则可以连接

1
2
3
from datasets import concatenate_datasets

data=concatenate_datasets([bookcorpus,wiki])
save
  • dataset.to_csv()
  • dataset.to_json()
  • dataset.to_pandas()
  • dataset.to_dict()

list_metrics

查看全部的metrics

1
2
3
4
5
from datasets import list_metrics
metrics_list=list_metrics()
metrics_list

['accuracy', 'bertscore', 'bleu', 'bleurt', 'cer', 'comet', 'coval', 'cuad', 'f1', 'gleu', 'glue', 'indic_glue', 'matthews_correlation', 'meteor', 'pearsonr', 'precision', 'recall', 'rouge', 'sacrebleu', 'sari', 'seqeval', 'spearmanr', 'squad', 'squad_v2', 'super_glue', 'wer', 'wiki_split', 'xnli']

load_metric

加载metric进行评测

1
2
3
4
5
6
7
8
9
10
11
12
from datasets import load_metric
metric = load_metric('glue', 'mrpc')

#查看输入输出
metric.inputs_description

#使用
result=metric.compute(predictions=predictions,references=references)

#batch
metric.add_batch(predictions=predictions,references=references)
metric.compute()
api