-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencoder.py
More file actions
124 lines (87 loc) · 3.33 KB
/
encoder.py
File metadata and controls
124 lines (87 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
This file was initially copied from https://github.com/denisyarats/pytorch_sac_ae
Changes were made to the following classes/functions: (TODO(review): the list of changed classes/functions is missing; fill it in.)
"""
import torch
import torch.nn as nn
def tie_weights(src, trg):
    """Make ``trg`` share the ``weight`` and ``bias`` parameters of ``src``.

    After the call, both modules hold references to the very same parameter
    objects, so an update through one is visible through the other.

    Args:
        src: module whose parameters are shared.
        trg: module that receives references to ``src``'s parameters; must be
            of exactly the same type as ``src``.
    """
    assert type(src) == type(trg)
    # Same order as assigning .weight then .bias directly.
    for attr_name in ('weight', 'bias'):
        setattr(trg, attr_name, getattr(src, attr_name))
# Side length of the final conv feature map, keyed by number of conv layers.
# The values match an 84x84 input: the first 3x3 stride-2 conv yields
# (84 - 3) // 2 + 1 = 41, and every further 3x3 stride-1 conv removes 2.
OUT_DIM = {n: (84 - 3) // 2 + 1 - 2 * (n - 1) for n in (2, 4, 6)}
class PixelEncoder(nn.Module):
    """Convolutional encoder of pixels observations.

    Architecture: ``num_layers`` 3x3 convolutions, each followed by ReLU. The
    first conv uses stride 2 and the rest use stride 1. The result is
    flattened, projected to ``feature_dim`` by a linear layer, then passed
    through LayerNorm and tanh.

    Args:
        obs_shape: ``(channels, height, width)`` of a single observation.
        feature_dim: size of the output feature vector.
        num_layers: number of conv layers; must be >= 1.
        num_filters: number of output channels of every conv layer.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32):
        super().__init__()

        assert len(obs_shape) == 3
        if num_layers < 1:
            raise ValueError(
                'num_layers must be >= 1, got %r' % (num_layers,)
            )

        self.feature_dim = feature_dim
        self.num_layers = num_layers

        self.convs = nn.ModuleList(
            [nn.Conv2d(obs_shape[0], num_filters, (3, 3), stride=(2, 2))]
        )
        for _ in range(num_layers - 1):
            self.convs.append(
                nn.Conv2d(num_filters, num_filters, (3, 3), stride=(1, 1))
            )

        # Derive the flattened conv output size from the real input shape
        # instead of a fixed lookup table. This supports any layer count and
        # spatial size, including non-square inputs. For 84x84 inputs with
        # 2/4/6 layers it yields the same size as the old table.
        self.fc = nn.Linear(self._conv_output_numel(obs_shape), self.feature_dim)
        self.ln = nn.LayerNorm(self.feature_dim)

        # Intermediate activations from the most recent forward pass, read by
        # ``log``. NOTE(review): these are stored without detaching, so they
        # keep the autograd graph alive until the next forward pass.
        self.outputs = dict()

    def _conv_output_numel(self, obs_shape):
        """Return the per-sample element count of the conv stack's output."""
        # A shape-only dry run; ReLU is skipped since it does not change shape.
        with torch.no_grad():
            h = torch.zeros(1, *(int(d) for d in obs_shape))
            for conv in self.convs:
                h = conv(h)
        return h[0].numel()

    def reparameterize(self, mu, logstd):
        """Sample ``mu + eps * exp(logstd)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(logstd)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward_conv(self, obs):
        """Run the conv stack on ``obs`` and return it flattened per sample.

        ``obs`` is divided by 255 first, so it presumably holds raw pixel
        values in [0, 255]; TODO confirm against callers.
        """
        obs = obs / 255.
        self.outputs['obs'] = obs

        conv = torch.relu(self.convs[0](obs))
        self.outputs['conv1'] = conv

        for i in range(1, self.num_layers):
            conv = torch.relu(self.convs[i](conv))
            self.outputs['conv%s' % (i + 1)] = conv

        h = conv.view(conv.size(0), -1)
        return h

    def forward(self, obs, detach=False):
        """Encode ``obs`` into a ``feature_dim`` vector squashed by tanh.

        If ``detach`` is true, the conv features are detached, so no gradient
        flows back into the conv layers from this call.
        """
        h = self.forward_conv(obs)

        if detach:
            h = h.detach()

        h_fc = self.fc(h)
        self.outputs['fc'] = h_fc

        h_norm = self.ln(h_fc)
        self.outputs['ln'] = h_norm

        out = torch.tanh(h_norm)
        self.outputs['tanh'] = out

        return out

    def copy_conv_weights_from(self, source):
        """Tie convolutional layers"""
        # only tie conv layers
        for i in range(self.num_layers):
            tie_weights(src=source.convs[i], trg=self.convs[i])

    def log(self, L, step, log_freq):
        """Log activation histograms/images and layer params every ``log_freq`` steps."""
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_encoder/%s_hist' % k, v, step)
            # Tensors with more than 2 dims are also logged as an image
            # (first element of the batch).
            if len(v.shape) > 2:
                L.log_image('train_encoder/%s_img' % k, v[0], step)

        for i in range(self.num_layers):
            L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step)
        L.log_param('train_encoder/fc', self.fc, step)
        L.log_param('train_encoder/ln', self.ln, step)
class IdentityEncoder(nn.Module):
    """No-op encoder for observations that are already 1-D vectors.

    ``feature_dim``, ``num_layers`` and ``num_filters`` are accepted only so
    that the constructor takes the same arguments as PixelEncoder. They are
    not used.
    """

    def __init__(self, obs_shape, feature_dim, num_layers, num_filters):
        super().__init__()
        assert len(obs_shape) == 1
        # The output *is* the observation, so its length is the feature size.
        self.feature_dim = obs_shape[0]

    def forward(self, obs, detach=False):
        """Return ``obs`` unchanged; ``detach`` has no effect here."""
        return obs

    def copy_conv_weights_from(self, source):
        """There are no conv layers to tie; intentionally does nothing."""

    def log(self, L, step, log_freq):
        """There is nothing to log; intentionally does nothing."""
# Registry of encoder classes, keyed by the ``encoder_type`` name that
# make_encoder accepts.
_AVAILABLE_ENCODERS = dict(pixel=PixelEncoder, identity=IdentityEncoder)
def make_encoder(
    encoder_type, obs_shape, feature_dim, num_layers, num_filters
):
    """Construct the encoder registered under ``encoder_type``.

    Args:
        encoder_type: key into ``_AVAILABLE_ENCODERS``.
        obs_shape, feature_dim, num_layers, num_filters: passed positionally
            to the selected encoder class's constructor.

    Returns:
        A new instance of the selected encoder class.

    Raises:
        ValueError: if ``encoder_type`` is not a registered encoder name.
    """
    # An explicit check (instead of ``assert``) is not stripped by
    # ``python -O``, and the message lists the valid choices.
    if encoder_type not in _AVAILABLE_ENCODERS:
        raise ValueError(
            'Unknown encoder_type %r; expected one of %s'
            % (encoder_type, sorted(_AVAILABLE_ENCODERS))
        )
    return _AVAILABLE_ENCODERS[encoder_type](
        obs_shape, feature_dim, num_layers, num_filters
    )