Transformer-DeID: Deidentification of free-text clinical notes with transformers 1.0.0

File: <base>/transformer_deid/annotate.py (3,650 bytes)
import torch
import argparse
import os
import numpy as np
from pathlib import Path
import logging
from tqdm import tqdm
from transformer_deid.data import DeidDataset
from transformer_deid.tokenization import merge_sequences, encodings_to_label_list
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformer_deid.load_data import load_data, create_deid_dataset, save_labels
from transformer_deid.train import which_transformer_arch

logger = logging.getLogger(__name__)


def annotate(model, test_dataset: DeidDataset, device):
    """Annotates dataset with PHI labels.

       Args:
            model: trained HuggingFace model
            test_dataset: DeidDataset, see data.py
            device: either cuda or cpu

       Returns:
            new_labels: list of lists of [entity type, start, length] objects for each document
    """
    model = model.to(device)
    model.eval()

    logger.info(
        f'Running predictions over {len(test_dataset.encodings.encodings)} split sequences from {len(test_dataset.ids)} documents.'
    )
    # One logits array per split sequence; argmax over the label axis
    # yields a predicted label id for every token.
    results = [
        get_logits(encoding, model, device)
        for encoding in tqdm(test_dataset.encodings.encodings,
                             total=len(test_dataset.encodings.encodings))
    ]
    predicted_labels = np.argmax(results, axis=2)

    # Map label ids to label strings, e.g. 0 -> 'O'.
    id2label = model.config.id2label
    predicted_labels = [[id2label[token] for token in doc]
                        for doc in predicted_labels]

    labels = [
        encodings_to_label_list(doc,
                                test_dataset.encodings[i],
                                id2label=id2label)
        for i, doc in enumerate(predicted_labels)
    ]

    new_labels = merge_sequences(labels, test_dataset.ids)

    return new_labels
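
# Illustrative output (hypothetical note "Seen by Dr. Smith on 1/1/20"):
# annotate(...) returns one label list per document, e.g.
#     [[['NAME', 12, 5], ['DATE', 21, 6]]]
# where each entry is [entity type, start offset, length]; the entity
# names shown here are assumptions for illustration.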


def get_logits(encodings, model, device):
    """Return the model logits for the encodings of a *single* text example."""
    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        result = model(
            input_ids=torch.tensor([encodings.ids]).to(device),
            attention_mask=torch.tensor([encodings.attention_mask]).to(device))
    logits = result['logits'].cpu().numpy()
    return logits[0]
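
# Shape sketch (the numbers are assumptions): for a 512-token split sequence
# and a model with, say, 9 label classes, get_logits returns a (512, 9) float
# array; annotate() stacks one such array per sequence before the argmax.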


def parse_args():
    parser = argparse.ArgumentParser(
        description='Annotate text with a transformer-based PHI deidentification model.')

    parser.add_argument('-i',
                        '--test_path',
                        type=str,
                        help='directory containing the txt directory for the test set.')

    parser.add_argument('-m', '--model_path', type=str, help='model directory')

    parser.add_argument('-o',
                        '--output_path',
                        help='output path in which to optionally save the annotations',
                        default=None)

    args = parser.parse_args()

    return args
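
# Example invocation (the paths are illustrative):
#   python annotate.py -i /data/deid/test -m models/bert-i2b2-2014 -o out/
# The model directory name must start with the base architecture (e.g.
# 'bert'), since main() parses that prefix to pick the matching tokenizer.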


def main(args):
    test_path = args.test_path
    out_path = args.output_path
    model_dir = args.model_path
    # Infer the base architecture from the model directory name,
    # e.g. 'bert-i2b2-2014' -> 'bert', to select the matching tokenizer.
    base_architecture = os.path.basename(model_dir).split('-')[0].lower()
    _, tokenizer, _ = which_transformer_arch(base_architecture)

    data_dict = load_data(Path(test_path))
    test_dataset = create_deid_dataset(data_dict, tokenizer)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    logger.info(f'Running using {device}.')

    model = AutoModelForTokenClassification.from_pretrained(model_dir)

    annotations = annotate(model, test_dataset, device)

    if out_path is not None:
        logger.info(f'Saving annotations to {out_path}.')
        save_labels(annotations, data_dict['guid'], out_path)

    else:
        return annotations, test_dataset.ids


if __name__ == '__main__':
    args = parse_args()
    main(args)
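
# Programmatic use (a sketch; the paths and model name are assumptions):
#   from types import SimpleNamespace
#   args = SimpleNamespace(test_path='data/test',
#                          model_path='models/bert-i2b2-2014',
#                          output_path=None)
#   annotations, doc_ids = main(args)  # returned when no output path is given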