main
parent
d1c214070e
commit
5475e98007
|
@ -0,0 +1,38 @@
|
|||
import collections
|
||||
|
||||
import datasets
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import nltk
|
||||
nltk.download('punkt')
|
||||
nltk.download('averaged_perceptron_tagger')
|
||||
import tqdm
|
||||
|
||||
seed = 1234
|
||||
validate_size = 0.25
|
||||
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
train_data, test_data = datasets.load_dataset("imdb", split=["train", "test"])
|
||||
|
||||
print(nltk.pos_tag(nltk.word_tokenize("Hello there you stupid fucking whore mr parker")))
|
||||
|
||||
def tokenize(input):
|
||||
return {"tokens": nltk.word_tokenize(input["text"])}
|
||||
|
||||
train_data = train_data.map(tokenize)
|
||||
test_data = test_data.map(tokenize)
|
||||
|
||||
train_valid_data = train_data.train_test_split(test_size=validate_size)
|
||||
train_data = train_valid_data["train"]
|
||||
valid_data = train_valid_data["test"]
|
||||
|
||||
print(train_data)
|
||||
print(test_data)
|
||||
print(valid_data)
|
Loading…
Reference in New Issue