Transformer-DeID: Deidentification of free-text clinical notes with transformers 1.0.0

File: <base>/transformer_deid/load_data.py (6,150 bytes)
import csv
import os
from tqdm import tqdm
from importlib.resources import open_text
import json
from transformer_deid.label import Label
from transformer_deid.data import DeidDataset
from transformer_deid.tokenization import split_sequences, assign_tags, encode_tags


def load_label(filename):
    """
    Loads annotations from a CSV file.
    The CSV file must have entity_type, start, stop, and entity columns.
    """
    # TODO: make the label transform an argument; options: "base", "hipaa", "binary"
    label_map = get_label_map('base')

    with open(filename, 'r') as fp:
        csvreader = csv.reader(fp, delimiter=',', quotechar='"')
        header = next(csvreader)
        # identify which columns we want
        type_idx = header.index('entity_type')
        start_idx = header.index('start')
        stop_idx = header.index('stop')
        entity_idx = header.index('entity')

        # iterate through the CSV and load in the labels,
        # harmonizing entity types through the label map when one is given
        if label_map is not None:
            labels = [
                Label(entity_type=label_map[row[type_idx].upper()],
                      start=int(row[start_idx]),
                      length=int(row[stop_idx]) - int(row[start_idx]),
                      entity=row[entity_idx]) for row in csvreader
            ]
        else:
            labels = [
                Label(entity_type=row[type_idx].upper(),
                      start=int(row[start_idx]),
                      length=int(row[stop_idx]) - int(row[start_idx]),
                      entity=row[entity_idx]) for row in csvreader
            ]

    return labels
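
# Illustrative usage sketch (the file name and CSV contents below are hypothetical):
#
#   entity_type,start,stop,entity
#   NAME,18,26,John Doe
#
#   >>> labels = load_label('data/ann/record-001.gs')
#   >>> labels[0].entity_type   # 'NAME', after mapping through the 'base' transform
#   >>> labels[0].length        # 8, i.e. stop - start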


def load_data(path) -> dict:
    """Creates a dict of the dataset; path is a pathlib.Path to the dataset root."""
    examples = {'guid': [], 'text': [], 'ann': []}

    # for deid datasets, "path" is a folder containing txt/ann subfolders
    # "txt" subfolder has text files with the text of the examples
    # "ann" subfolder has annotation files with the labels
    txt_path = path / 'txt'
    ann_path = path / 'ann'
    for f in os.listdir(txt_path):
        if not f.endswith('.txt'):
            continue

        # guid is the file name without the .txt extension
        guid = f[:-4]
        with open(txt_path / f, 'r') as fp:
            text = fp.read()

        # load the annotations from disk
        # these datasets have consistent folder structures:
        #   root_path/txt/RECORD_NAME.txt - has text
        #   root_path/ann/RECORD_NAME.gs - has annotations
        if os.path.isdir(ann_path):
            labels = load_label(ann_path / f'{guid}.gs')
            examples['ann'].append(labels)

        examples['guid'].append(guid)
        examples['text'].append(text)

    return examples
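
# Illustrative usage sketch; the directory layout below is the one load_data
# expects (paths hypothetical):
#
#   dataset/
#       txt/record-001.txt   (note text)
#       ann/record-001.gs    (annotations)
#
#   >>> from pathlib import Path
#   >>> examples = load_data(Path('dataset'))
#   >>> examples['guid'][0]   # 'record-001'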


def create_deid_dataset(data_dict, tokenizer, label2id=None) -> DeidDataset:
    """Creates a dataset from set of texts and labels.

       Args:
            data_dict: dict of text data with 'text', 'ann', and 'guid' keys; e.g., from load_data()
            tokenizer: HuggingFace tokenizer, e.g., loaded from AutoTokenizer.from_pretrained()
            label2id: dict to convert label (e.g., 'O', 'DATE') to id (e.g., 0, 1)

       Returns: DeidDataset; see class definition in data.py
    """

    # specify dataset arguments
    split_long_sequences = True

    # split text/labels into multiple examples
    # (1) tokenize text
    # (2) identify split points
    # (3) output text as it was originally
    texts = data_dict['text']
    labels = data_dict['ann']
    ids = data_dict['guid']

    if split_long_sequences:
        split_dict = split_sequences(tokenizer, texts, labels, ids=ids)
        texts = split_dict['texts']
        labels = split_dict['labels']
        guids = split_dict['guids']
    else:
        # without splitting, keep the original texts/labels and document ids
        guids = ids

    encodings = tokenizer(texts,
                          is_split_into_words=False,
                          return_offsets_mapping=True,
                          padding=True,
                          truncation=True)

    if labels:
        # use the offset mappings in train_encodings to assign labels to tokens
        tags = assign_tags(encodings, labels)

        # encodings are dicts with three elements:
        #   'input_ids', 'attention_mask', 'offset_mapping'
        # these are used as kwargs to model training later
        tags = encode_tags(tags, encodings, label2id)

    else:
        tags = None

    # prepare a dataset compatible with Trainer module
    encodings.pop("offset_mapping")
    dataset = DeidDataset(encodings, tags, guids)

    return dataset
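
# Illustrative end-to-end sketch (the checkpoint name and dataset path are
# hypothetical; label2id is built from the training labels with get_labels):
#
#   >>> from pathlib import Path
#   >>> from transformers import AutoTokenizer
#   >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
#   >>> data_dict = load_data(Path('dataset'))
#   >>> label2id = {label: i for i, label in enumerate(get_labels(data_dict['ann']))}
#   >>> dataset = create_deid_dataset(data_dict, tokenizer, label2id=label2id)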


def get_label_map(transform):
    """ Gets dictionary of labels to convert from specific labels (may differ across datasets) to more general labels. """
    with open_text('transformer_deid', 'label.json') as fp:
        label_map = json.load(fp)

    # label_membership has different label transforms as keys
    if transform not in label_map:
        raise KeyError('Unable to find label transform %s in label.json' %
                       transform)
    label_map = label_map[transform]

    # label_map has items "harmonized_label": ["label 1", "label 2", ...]
    # invert this for the final label mapping
    return {
        label: harmonized_label
        for harmonized_label, original_labels in label_map.items()
        for label in original_labels
    }
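
# Illustrative example of the inversion (the label.json fragment is hypothetical):
# given {"base": {"NAME": ["PATIENT", "DOCTOR"], "DATE": ["DATE"]}},
# get_label_map('base') returns
# {'PATIENT': 'NAME', 'DOCTOR': 'NAME', 'DATE': 'DATE'}.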


def get_labels(labels_each_doc):
    """Gets the list of labels for this data set."""
    unique_labels = list(
        set(label.entity_type for labels in labels_each_doc
            for label in labels))

    unique_labels.sort()

    # add the 'O' (outside) tag - ensure it is the first element for the label2id and id2label dicts
    if 'O' in unique_labels:
        unique_labels.remove('O')
    unique_labels = ['O'] + unique_labels

    return unique_labels
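
# Illustrative sketch: get_labels keeps 'O' at index 0, so the derived
# label2id/id2label dicts always map 'O' to id 0:
#
#   >>> unique_labels = get_labels(examples['ann'])
#   >>> label2id = {label: i for i, label in enumerate(unique_labels)}
#   >>> id2label = {i: label for label, i in label2id.items()}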


def save_labels(labels, ids, out_path):
    """Writes one annotation CSV per document (named <guid>.gs) to out_path."""
    os.makedirs(out_path, exist_ok=True)
    header = ['start', 'stop', 'entity', 'entity_type']

    for doc, guid in tqdm(zip(labels, ids), total=len(ids)):
        label_list = [[
            label.start, label.start + label.length, label.entity,
            label.entity_type
        ] for label in doc]
        with open(f'{out_path}/{guid}.gs', 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(label_list)
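
# Illustrative usage sketch ('out/ann' is a hypothetical output folder):
#
#   >>> save_labels(examples['ann'], examples['guid'], 'out/ann')
#
# The written files use the same columns load_label expects, so they can be
# read back with load_label().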