
Commit b82ce31

Author: MotionGPT
Release MotionGPT V1.0
1 parent e77e357 commit b82ce31

165 files changed: 24459 additions & 81 deletions

File tree


README.md

Lines changed: 315 additions & 81 deletions
Large diffs are not rendered by default.

app.py

Lines changed: 365 additions & 0 deletions
@@ -0,0 +1,365 @@
import gradio as gr
import random
import torch
import time
import cv2
import os
import numpy as np
import pytorch_lightning as pl
import moviepy.editor as mp
from pathlib import Path
from mGPT.data.build_data import build_data
from mGPT.models.build_model import build_model
from mGPT.config import parse_args
from scipy.spatial.transform import Rotation as RRR
import mGPT.render.matplot.plot_3d_global as plot_3d
from mGPT.render.pyrender.hybrik_loc2rot import HybrIKJointsToRotmat
from mGPT.render.pyrender.smpl_render import SMPLRender
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa

# Load model
cfg = parse_args(phase="webui")  # parse config file
cfg.FOLDER = 'cache'
output_dir = Path(cfg.FOLDER)
output_dir.mkdir(parents=True, exist_ok=True)
pl.seed_everything(cfg.SEED_VALUE)
if cfg.ACCELERATOR == "gpu":
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
datamodule = build_data(cfg, phase="test")
model = build_model(cfg, datamodule)
state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
model.load_state_dict(state_dict)
model.to(device)

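# Note: cfg.TEST.CHECKPOINTS should point to a pretrained MotionGPT checkpoint;
# the same webui config also provides cfg.model.whisper_path used below.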
# Speech-to-text model for voice input
audio_processor = WhisperProcessor.from_pretrained(cfg.model.whisper_path)
audio_model = WhisperForConditionalGeneration.from_pretrained(cfg.model.whisper_path).to(device)
forced_decoder_ids_zh = audio_processor.get_decoder_prompt_ids(language="zh", task="translate")
forced_decoder_ids_en = audio_processor.get_decoder_prompt_ids(language="en", task="translate")
# Default to English to match the UI's initial "Speech language" setting
forced_decoder_ids = forced_decoder_ids_en
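# Whisper's "translate" task emits English text, so speech in either language
# reaches the motion language model as an English prompt.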

# HTML components used in chat messages
Video_Components = """
<div class="side-video" style="position: relative;">
    <video width="340" autoplay loop>
        <source src="file/{video_path}" type="video/mp4">
    </video>
    <a class="videodl-button" href="file/{video_path}" download="{video_fname}" title="Download Video">
        <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-video"><path d="m22 8-6 4 6 4V8Z"/><rect width="14" height="12" x="2" y="6" rx="2" ry="2"/></svg>
    </a>
    <a class="npydl-button" href="file/{motion_path}" download="{motion_fname}" title="Download Motion">
        <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-person-standing"><circle cx="12" cy="5" r="1"/><path d="m9 20 3-6 3 6"/><path d="m6 8 6 2 6-2"/><path d="M12 10v4"/></svg>
    </a>
</div>
"""

Text_Components = """
<h3 class="side-content">{msg}</h3>
"""


def motion_token_to_string(motion_token, lengths, codebook_size=512):
    # Wrap each clip's discrete token ids in boundary tokens
    # (<motion_id_512> ... <motion_id_513> for the default codebook size).
    motion_string = []
    for i in range(motion_token.shape[0]):
        motion_i = motion_token[i].cpu(
        ) if motion_token.device.type == 'cuda' else motion_token[i]
        motion_list = motion_i.tolist()[:lengths[i]]
        motion_string.append(
            (f'<motion_id_{codebook_size}>' +
             ''.join([f'<motion_id_{int(token)}>' for token in motion_list]) +
             f'<motion_id_{codebook_size + 1}>'))
    return motion_string
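
# Example with hypothetical inputs (one clip holding three token ids):
#   motion_token_to_string(torch.tensor([[5, 9, 2]]), [3])
#   -> ['<motion_id_512><motion_id_5><motion_id_9><motion_id_2><motion_id_513>']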


def render_motion(data, feats, method='fast'):
    # Save the raw features to .npy and render the joints to an MP4;
    # returns the output paths and filenames.
    fname = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(
        time.time())) + str(np.random.randint(10000, 99999))
    video_fname = fname + '.mp4'
    feats_fname = fname + '.npy'
    output_npy_path = os.path.join(output_dir, feats_fname)
    output_mp4_path = os.path.join(output_dir, video_fname)
    np.save(output_npy_path, feats)

    if method == 'slow':
        # Full SMPL mesh rendering with pyrender
        if len(data.shape) == 4:
            data = data[0]
        data = data - data[0, 0]
        pose_generator = HybrIKJointsToRotmat()
        pose = pose_generator(data)
        pose = np.concatenate([
            pose,
            np.stack([np.stack([np.eye(3)] * pose.shape[0], 0)] * 2, 1)
        ], 1)
        shape = [768, 768]
        render = SMPLRender(cfg.RENDER.SMPL_MODEL_PATH)

        if not os.environ.get("PYOPENGL_PLATFORM"):
            os.environ["DISPLAY"] = ":0.0"
            os.environ["PYOPENGL_PLATFORM"] = "egl"

        size = (shape[1], shape[0])
        fps = 20.0
        fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
        videoWriter = cv2.VideoWriter(output_mp4_path, fourcc, fps, size)
        r = RRR.from_rotvec(np.array([np.pi, 0.0, 0.0]))
        pose[:, 0] = np.matmul(r.as_matrix().reshape(1, 3, 3), pose[:, 0])
        for i in range(data.shape[0]):
            img = np.zeros([shape[0], shape[1], 3])
            aroot = data[[i], 0] + np.array([[0.0, 0.0, 30.0]])
            aroot[:, 1] = -aroot[:, 1]
            params = dict(pred_shape=np.zeros([1, 10]),
                          pred_root=aroot,
                          pred_pose=pose[[i]])
            renderImg = render.render(img.copy(), params)
            renderImg = (renderImg * 255).astype(np.uint8)
            videoWriter.write(renderImg)
        videoWriter.release()
        # Re-encode to H.264 so the video plays in browsers
        output_video_h264_name = output_mp4_path[:-4] + '_h264.mp4'
        command = 'ffmpeg -y -i {} -vcodec h264 {}'.format(
            output_mp4_path, output_video_h264_name)
        os.system(command)
        output_mp4_path = output_video_h264_name
        video_fname = video_fname[:-4] + '_h264.mp4'
    elif method == 'fast':
        # Stick-figure skeleton via matplotlib, then GIF -> MP4 conversion
        output_gif_path = output_mp4_path[:-4] + '.gif'
        if len(data.shape) == 3:
            data = data[None]
        if isinstance(data, torch.Tensor):
            data = data.cpu().numpy()
        pose_vis = plot_3d.draw_to_batch(data, [''], [output_gif_path])
        out_video = mp.VideoFileClip(output_gif_path)
        out_video.write_videofile(output_mp4_path)

    return output_mp4_path, video_fname, output_npy_path, feats_fname
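
# Usage sketch (hypothetical shapes): joints of [frames, 22, 3] plus the
# matching feature array; returns video/feature paths and filenames:
#   mp4_path, mp4_name, npy_path, npy_name = render_motion(joints, feats, method='fast')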


def load_motion(motion_uploaded, method):
    # Encode an uploaded .npy motion into tokens and render a preview video
    file = motion_uploaded['file']

    feats = torch.tensor(np.load(file), device=model.device)
    if len(feats.shape) == 2:
        feats = feats[None]
    # feats = model.datamodule.normalize(feats)

    # Motion tokens
    motion_lengths = feats.shape[1]  # number of frames (feats is [batch, frames, dim])
    motion_token, _ = model.vae.encode(feats)

    motion_token_string = model.lm.motion_token_to_string(
        motion_token, [motion_token.shape[1]])[0]
    motion_token_length = motion_token.shape[1]

    # Motion rendered
    joints = model.datamodule.feats2joints(feats.cpu()).cpu().numpy()
    output_mp4_path, video_fname, output_npy_path, joints_fname = render_motion(
        joints,
        feats.to('cpu').numpy(), method)

    motion_uploaded.update({
        "feats": feats,
        "joints": joints,
        "motion_video": output_mp4_path,
        "motion_video_fname": video_fname,
        "motion_joints": output_npy_path,
        "motion_joints_fname": joints_fname,
        "motion_lengths": motion_lengths,
        "motion_token": motion_token,
        "motion_token_string": motion_token_string,
        "motion_token_length": motion_token_length,
    })

    return motion_uploaded
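
# Assumption: the uploaded .npy holds HumanML3D-style motion features
# ([frames, 263]); datamodule.feats2joints maps them to [frames, 22, 3] joints.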


def add_text(history, text, motion_uploaded, data_stored, method):
    # Append a user text message (optionally with an uploaded motion) to the chat
    data_stored = data_stored + [{'user_input': text}]

    if 'file' in motion_uploaded:
        text = Text_Components.format(msg=text)
        motion_uploaded = load_motion(motion_uploaded, method)
        output_mp4_path = motion_uploaded['motion_video']
        video_fname = motion_uploaded['motion_video_fname']
        output_npy_path = motion_uploaded['motion_joints']
        joints_fname = motion_uploaded['motion_joints_fname']

        text = text + Video_Components.format(video_path=output_mp4_path,
                                              video_fname=video_fname,
                                              motion_path=output_npy_path,
                                              motion_fname=joints_fname)
    else:
        text = f"""<h3>{text}</h3>"""
    history = history + [(text, None)]
    return history, gr.update(value="",
                              interactive=False), motion_uploaded, data_stored


def add_audio(history, audio_path, data_stored):
    # Transcribe/translate recorded speech with Whisper and queue it as user input
    audio, sampling_rate = librosa.load(audio_path, sr=16000)  # Whisper's training sampling rate, do not modify
    input_features = audio_processor(
        audio, sampling_rate=sampling_rate,
        return_tensors="pt").input_features.to(device)
    predicted_ids = audio_model.generate(input_features,
                                         forced_decoder_ids=forced_decoder_ids)
    text_input = audio_processor.batch_decode(predicted_ids,
                                              skip_special_tokens=True)
    text_input = str(text_input).strip('[]"')
    data_stored = data_stored + [{'user_input': text_input}]
    history = history + [(text_input, None)]

    return history, data_stored


def add_file(history, file, txt, motion_uploaded):
    # Register an uploaded motion file and append its placeholder to the textbox
    motion_uploaded['file'] = file.name
    txt = txt.replace(" <Motion_Placeholder>", "") + " <Motion_Placeholder>"
    return history, gr.update(value=txt, interactive=True), motion_uploaded
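
# "<Motion_Placeholder>" marks where the uploaded motion's token string is
# spliced into the prompt (see model.lm.placeholder_fulfill in bot below).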


def bot(history, motion_uploaded, data_stored, method):
    # Build the prompt, run MotionGPT, render any generated motion, and
    # stream the response back into the chat history.
    motion_length, motion_token_string = motion_uploaded[
        "motion_lengths"], motion_uploaded["motion_token_string"]

    user_input = data_stored[-1]['user_input']
    prompt = model.lm.placeholder_fulfill(user_input, motion_length,
                                          motion_token_string, "")
    data_stored[-1]['model_input'] = prompt
    batch = {
        "length": [motion_length],
        "text": [prompt],
    }

    outputs = model(batch, task="t2m")
    out_feats = outputs["feats"][0]
    out_lengths = outputs["length"][0]
    out_joints = outputs["joints"][:out_lengths].detach().cpu().numpy()
    out_texts = outputs["texts"][0]
    output_mp4_path, video_fname, output_npy_path, joints_fname = render_motion(
        out_joints,
        out_feats.to('cpu').numpy(), method)

    # Reset the uploaded-motion state once it has been consumed
    motion_uploaded = {
        "feats": None,
        "joints": None,
        "motion_video": None,
        "motion_lengths": 0,
        "motion_token": None,
        "motion_token_string": '',
        "motion_token_length": 0,
    }

    data_stored[-1]['model_output'] = {
        "feats": out_feats,
        "joints": out_joints,
        "length": out_lengths,
        "texts": out_texts,
        "motion_video": output_mp4_path,
        "motion_video_fname": video_fname,
        "motion_joints": output_npy_path,
        "motion_joints_fname": joints_fname,
    }

    # Motion-only reply
    if out_texts == '<Motion_Placeholder>':
        response = [
            Video_Components.format(video_path=output_mp4_path,
                                    video_fname=video_fname,
                                    motion_path=output_npy_path,
                                    motion_fname=joints_fname)
        ]
    # Text reply with an embedded motion
    elif '<Motion_Placeholder>' in out_texts:
        response = [
            Text_Components.format(
                msg=out_texts.split("<Motion_Placeholder>")[0]),
            Video_Components.format(video_path=output_mp4_path,
                                    video_fname=video_fname,
                                    motion_path=output_npy_path,
                                    motion_fname=joints_fname),
            Text_Components.format(
                msg=out_texts.split("<Motion_Placeholder>")[1]),
        ]
    # Plain text reply
    else:
        response = f"""<h3>{out_texts}</h3>"""

    # Stream the reply piece by piece (characters for plain text,
    # whole HTML components otherwise)
    history[-1][1] = ""
    for character in response:
        history[-1][1] += character
        time.sleep(0.02)
        yield history, motion_uploaded, data_stored
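
# Because bot() is a generator, Gradio re-renders the chatbot on every yield,
# which produces the typing-style animation.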


with open("assets/css/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()

with gr.Blocks(css=customCSS) as demo:

    # Variables
    motion_uploaded = gr.State({
        "feats": None,
        "joints": None,
        "motion_video": None,
        "motion_lengths": 0,
        "motion_token": None,
        "motion_token_string": '',
        "motion_token_length": 0,
    })
    data_stored = gr.State([])

    gr.Markdown(
        "# Welcome to MotionGPT! \n ## You can type or upload a numpy file containing motion joints."
    )

    chatbot = gr.Chatbot([], elem_id="mGPT", height=600, label="MotionGPT")

    with gr.Row():
        with gr.Column(scale=0.85):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press enter, or insert motion",
                container=False)
            with gr.Row():
                aud = gr.Audio(label='Speak', source="microphone", type='filepath')
                btn = gr.UploadButton("📁 Upload motion",
                                      elem_id="upload",
                                      file_types=["file"],
                                      variant='primary')
                regen = gr.Button("🔄 Regenerate", elem_id="regen")
                clear = gr.ClearButton([txt, chatbot, aud], value='🗑️ Clear')

        with gr.Column(scale=0.15, min_width=150):
            method = gr.Dropdown(["slow", "fast"],
                                 label="Render method",
                                 interactive=True,
                                 elem_id="method",
                                 value="fast")
            language = gr.Dropdown(["English", "中文"],
                                   label="Speech language",
                                   interactive=True,
                                   elem_id="language",
                                   value="English")
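
    # Note: nothing in this release wires the "Speech language" dropdown to the
    # Whisper prompt ids; a minimal sketch of such a handler (hypothetical,
    # not part of this commit):
    #
    #   def set_speech_language(lang):
    #       global forced_decoder_ids
    #       forced_decoder_ids = (forced_decoder_ids_zh if lang == "中文"
    #                             else forced_decoder_ids_en)
    #
    #   language.change(set_speech_language, [language], None)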

    txt_msg = txt.submit(
        add_text, [chatbot, txt, motion_uploaded, data_stored, method],
        [chatbot, txt, motion_uploaded, data_stored],
        queue=False).then(bot, [chatbot, motion_uploaded, data_stored, method],
                          [chatbot, motion_uploaded, data_stored])
    txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
    file_msg = btn.upload(add_file, [chatbot, btn, txt, motion_uploaded],
                          [chatbot, txt, motion_uploaded],
                          queue=False)
    aud_msg = aud.stop_recording(
        add_audio, [chatbot, aud, data_stored], [chatbot, data_stored],
        queue=False).then(bot, [chatbot, motion_uploaded, data_stored, method],
                          [chatbot, motion_uploaded, data_stored])
    regen_msg = regen.click(bot,
                            [chatbot, motion_uploaded, data_stored, method],
                            [chatbot, motion_uploaded, data_stored])

demo.queue()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=8888, debug=True)
