execution time breakdown #51

Open · wants to merge 1 commit into base: main
115 changes: 88 additions & 27 deletions train/compute/pt/pytorch_linear.py
@@ -34,7 +34,9 @@ def train_cpu(
loss_f = nn.CrossEntropyLoss()

# model.train()
start_time = time.time()
events = {"start_all": [], "stop_all": [], "start_fwd": [], "stop_fwd": [], "start_bwd": [], "stop_bwd": [], "start_opt": [], "stop_opt":[]}
times = {"all": 0, "fwd": 0, "bwd": 0, "opt": 0}
events["start_all"].append(time.time())

for i in range(args.steps + args.warmups):
data = torch.randn(batch_size, input_size, device=device)
@@ -45,15 +47,26 @@
if data_type == "float16":
data = data.half()

events["start_all"].append(time.time())
optimizer.zero_grad()
events["start_fwd"].append(time.time())
output = model(data).float()
events["stop_fwd"].append(time.time())
loss = loss_f(output, target)
events["start_bwd"].append(time.time())
loss.backward()
events["stop_bwd"].append(time.time())
events["start_opt"].append(time.time())
optimizer.step()
events["stop_opt"].append(time.time())
events["stop_all"].append(time.time())
if i < args.warmups:
start_time = time.time()
for t in events.values():
t.clear()

return time.time() - start_time, loss
for key in ["all", "fwd", "bwd", "opt"]:
times[key] = sum([te-ts for ts, te in zip(events["start_"+key], events["stop_"+key])])
return times, loss
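
Distilled from the change above, the per-phase wall-clock timing for the CPU (and TPU) paths follows the pattern below. This is only an illustrative sketch; the placeholder work() stands in for the real forward/backward/optimizer calls:

import time

def timed_phases(steps, warmups, work):
    # start/stop timestamp lists per phase, mirroring the events dict above
    events = {f"{edge}_{phase}": []
              for phase in ("all", "fwd", "bwd", "opt")
              for edge in ("start", "stop")}
    for i in range(steps + warmups):
        events["start_all"].append(time.time())
        for phase in ("fwd", "bwd", "opt"):
            events[f"start_{phase}"].append(time.time())
            work(phase)  # placeholder for forward / backward / optimizer.step()
            events[f"stop_{phase}"].append(time.time())
        events["stop_all"].append(time.time())
        if i < warmups:
            # measurements taken during warm-up iterations are discarded
            for stamps in events.values():
                stamps.clear()
    # total seconds spent in each phase across the measured steps
    return {k: sum(stop - start for start, stop in
                   zip(events["start_" + k], events["stop_" + k]))
            for k in ("all", "fwd", "bwd", "opt")}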


def train_gpu(
@@ -67,10 +80,11 @@ def train_gpu(
model = apex.fp16_utils.network_to_half(model)

# model.train()
times = {"all": 0, "fwd": 0, "bwd": 0, "opt": 0}
events = {"start_all": [], "stop_all": [], "start_fwd": [], "stop_fwd": [], "start_bwd": [], "stop_bwd": [], "start_opt": [], "stop_opt":[]}
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
total_time = 0.0
for e in events.keys():
events[e] = torch.cuda.Event(enable_timing=True)

for i in range(args.steps + args.warmups):
data = torch.randn(batch_size, input_size, device=device)
@@ -81,20 +95,26 @@
if data_type == "float16":
data = data.half()

if i >= args.warmups:
start_event.record()
events["start_all"].record()

optimizer.zero_grad()
events["start_fwd"].record()
output = model(data).float()
events["stop_fwd"].record()
loss = loss_f(output, target)
events["start_bwd"].record()
loss.backward()
events["stop_bwd"].record()
events["start_opt"].record()
optimizer.step()
events["stop_opt"].record()
if i >= args.warmups:
end_event.record()
events["stop_all"].record()
torch.cuda.synchronize()
total_time += start_event.elapsed_time(end_event) * 1.0e-3
for key in ["all", "fwd", "bwd", "opt"]:
times[key] += events["start_"+key].elapsed_time(events["stop_"+key]) * 1.0e-3

return total_time, loss
return times, loss
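
For the GPU path, the key change is going from a single start/end event pair to one torch.cuda.Event pair per phase. The underlying timing pattern is the standard one sketched below (the matmul workload and sizes are only illustrative):

import torch

start = torch.cuda.Event(enable_timing=True)
stop = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device="cuda")
start.record()                   # enqueue a timestamp on the current stream
y = x @ x                        # region being measured
stop.record()
torch.cuda.synchronize()         # ensure both events have actually completed
elapsed_s = start.elapsed_time(stop) * 1.0e-3   # elapsed_time() returns milliseconds

record() only enqueues the events, so the synchronize() once per iteration (as in the diff) is needed before reading elapsed_time().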


def train_tpu(
@@ -105,7 +125,9 @@ def train_tpu(
loss_f = nn.CrossEntropyLoss().to(device)

# model.train()
start_time = time.time()
times = {"all": 0, "fwd": 0, "bwd": 0, "opt": 0}
events = {"start_all": [], "stop_all": [], "start_fwd": [], "stop_fwd": [], "start_bwd": [], "stop_bwd": [], "start_opt": [], "stop_opt":[]}
events["start_all"].append(time.time())

for i in range(args.steps + args.warmups):
data = torch.randn(batch_size, input_size, device=device)
@@ -114,24 +136,35 @@
)
# data, target = data.to(device), target.to(device)

events["start_all"].append(time.time())
optimizer.zero_grad()
events["start_fwd"].append(time.time())
output = model(data).float()
events["stop_fwd"].append(time.time())
loss = loss_f(output, target)
events["start_bwd"].append(time.time())
loss.backward()
events["stop_bwd"].append(time.time())
events["start_opt"].append(time.time())
optimizer.step()
xm.mark_step()
events["stop_opt"].append(time.time())
events["stop_all"].append(time.time())
if i < args.warmups:
start_time = time.time()
for t in events.values():
t.clear()

return time.time() - start_time, loss
for key in ["all", "fwd", "bwd", "opt"]:
times[key] = sum([te-ts for ts, te in zip(events["start_"+key], events["stop_"+key])])
return times, loss


def train(
model, device, optimizer, data_type, input_size, output_size, batch_size, args
):

if device.type == "cpu":
elap, loss = train_cpu(
times, loss = train_cpu(
model,
device,
optimizer,
@@ -143,7 +176,7 @@ def train(
)

elif device.type == "cuda":
elap, loss = train_gpu(
times, loss = train_gpu(
model,
device,
optimizer,
@@ -155,7 +188,7 @@ def train(
)

elif device.type == "xla":
elap, loss = train_tpu(
times, loss = train_tpu(
model,
device,
optimizer,
@@ -166,7 +199,7 @@ def train(
args,
)

return elap, loss
return times, loss


def run_single(args, layers_size, batch_size):
@@ -203,6 +236,14 @@ def run_single(args, layers_size, batch_size):
optimizer = apex.optimizers.FusedLAMB(
model.parameters(), lr=lr, set_grad_none=True
)
elif optimizer_type == "adam":
optimizer = apex.optimizers.FusedAdam(
model.parameters(), lr=lr, set_grad_none=True
)
elif optimizer_type == "adagrad":
optimizer = apex.optimizers.FusedAdagrad(
model.parameters(), lr=lr, set_grad_none=True
)
else:
assert 0, "Unsupported optimizer type"

@@ -225,10 +266,10 @@ def run_single(args, layers_size, batch_size):
else:
assert 0, "Unsupported optimizer type"

elap, loss = train(
times, loss = train(
model, dev, optimizer, data_type, layers_size[0], layers_size[-1], batch_size, args
)
return elap, loss
return times, loss


def run(args, dataset):
@@ -237,41 +278,61 @@ def run(args, dataset):
"--------------------------------------------------------------------------------"
)
print(
" #Layer Input Hidden Output Batch Time(s)/step QPS Rate(TF/s)"
" Num Layers Batch Time(s)/step: All FWD BWD OPT QPS (TF/s):FWD BWD OPT(GB/s)"
)
print(
"--------------------------------------------------------------------------------"
)

for i in range(len(dataset)):
layers_size, batch_size = dataset[i]
elap, loss = run_single(
times, loss = run_single(
args, layers_size, batch_size
)
elap = times["all"]
fwd_t = times["fwd"]
bwd_t = times["bwd"]
opt_t = times["opt"]

elap /= args.steps
fwd_t /= args.steps
bwd_t /= args.steps
opt_t /= args.steps

flops = 0
for i in range(len(layers_size)-1):
flops += layers_size[i] * layers_size[i+1]
params = flops
bytes_per_dtype = 4 if args.dtype == "float" else 2
params *= bytes_per_dtype
# how many bytes of optimizer states?
flops *= batch_size

# Forward 2x and Backward 4x
fwd_flops = flops * 2
bwd_flops = flops * 6
flops *= 6



QPS = batch_size / elap

# The hidden layer size could vary, but for now keeping for backward
# compatibility
print(
"{0:6}, {1:6}, {2:6}, {3:6}, {4:6}, {5:10.6f}, {6:8.1f}, {7:10.1f}".format(
"{0:6}, {1:6}, {2:.3f}, {3:.3f}, {4:.3f}, {5:.3f}, {6:8.1f}, {7:10.1f}, {8:.3f}, {9:.3f}, {10:.3f}".format(
len(layers_size),
layers_size[0],
layers_size[1],
layers_size[-1],
batch_size,
elap,
fwd_t,
bwd_t,
opt_t,
QPS,
flops / elap / 1.0e12,
fwd_flops / fwd_t / 1.0e12,
bwd_flops / bwd_t / 1.0e12,
params / opt_t / 1.0e9

)
)
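
To make the throughput accounting concrete, here is the arithmetic for a hypothetical layers_size = [1024, 1024, 1024] network at batch size 512, using the 2x-forward / 4x-backward rule of thumb referenced in the comment above (the sizes are illustrative, not taken from the PR):

weights    = 1024 * 1024 + 1024 * 1024   # sum of layer_in * layer_out        ~2.10e6
macs       = weights * 512               # multiply-accumulates per step      ~1.07e9
fwd_flops  = 2 * macs                    # 1 mul + 1 add per MAC              ~2.15e9
bwd_flops  = 4 * macs                    # grads w.r.t. inputs and weights    ~4.29e9
step_flops = fwd_flops + bwd_flops       # the overall 6x factor              ~6.44e9
opt_bytes  = weights * 4                 # fp32 weight bytes the optimizer touches, ~8.4 MB

These quantities are then divided by the measured per-step phase times to produce the rate columns (TF/s for forward/backward, GB/s for the optimizer) in the table.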

@@ -304,7 +365,7 @@ def dash_separated_ints(value):
"--optimizer-type",
default="sgd",
help="Optimizer: SGD",
choices=["sgd", "lamb"],
choices=["sgd", "lamb", "adam", "adagrad"],
)
parser.add_argument(
"--dtype",