rlm@1: // 7zAes.cpp rlm@1: rlm@1: #include "StdAfx.h" rlm@1: rlm@1: extern "C" rlm@1: { rlm@1: #include "../../../C/Sha256.h" rlm@1: } rlm@1: rlm@1: #include "Windows/Synchronization.h" rlm@1: #include "../Common/StreamObjects.h" rlm@1: #include "../Common/StreamUtils.h" rlm@1: #include "7zAes.h" rlm@1: #include "MyAes.h" rlm@1: rlm@1: #ifndef EXTRACT_ONLY rlm@1: #include "RandGen.h" rlm@1: #endif rlm@1: rlm@1: using namespace NWindows; rlm@1: rlm@1: namespace NCrypto { rlm@1: namespace NSevenZ { rlm@1: rlm@1: bool CKeyInfo::IsEqualTo(const CKeyInfo &a) const rlm@1: { rlm@1: if (SaltSize != a.SaltSize || NumCyclesPower != a.NumCyclesPower) rlm@1: return false; rlm@1: for (UInt32 i = 0; i < SaltSize; i++) rlm@1: if (Salt[i] != a.Salt[i]) rlm@1: return false; rlm@1: return (Password == a.Password); rlm@1: } rlm@1: rlm@1: void CKeyInfo::CalculateDigest() rlm@1: { rlm@1: if (NumCyclesPower == 0x3F) rlm@1: { rlm@1: UInt32 pos; rlm@1: for (pos = 0; pos < SaltSize; pos++) rlm@1: Key[pos] = Salt[pos]; rlm@1: for (UInt32 i = 0; i < Password.GetCapacity() && pos < kKeySize; i++) rlm@1: Key[pos++] = Password[i]; rlm@1: for (; pos < kKeySize; pos++) rlm@1: Key[pos] = 0; rlm@1: } rlm@1: else rlm@1: { rlm@1: CSha256 sha; rlm@1: Sha256_Init(&sha); rlm@1: const UInt64 numRounds = UInt64(1) << (NumCyclesPower); rlm@1: Byte temp[8] = { 0,0,0,0,0,0,0,0 }; rlm@1: for (UInt64 round = 0; round < numRounds; round++) rlm@1: { rlm@1: Sha256_Update(&sha, Salt, (size_t)SaltSize); rlm@1: Sha256_Update(&sha, Password, Password.GetCapacity()); rlm@1: Sha256_Update(&sha, temp, 8); rlm@1: for (int i = 0; i < 8; i++) rlm@1: if (++(temp[i]) != 0) rlm@1: break; rlm@1: } rlm@1: Sha256_Final(&sha, Key); rlm@1: } rlm@1: } rlm@1: rlm@1: bool CKeyInfoCache::Find(CKeyInfo &key) rlm@1: { rlm@1: for (int i = 0; i < Keys.Size(); i++) rlm@1: { rlm@1: const CKeyInfo &cached = Keys[i]; rlm@1: if (key.IsEqualTo(cached)) rlm@1: { rlm@1: for (int j = 0; j < kKeySize; j++) rlm@1: key.Key[j] = cached.Key[j]; rlm@1: if (i != 0) rlm@1: { rlm@1: Keys.Insert(0, cached); rlm@1: Keys.Delete(i+1); rlm@1: } rlm@1: return true; rlm@1: } rlm@1: } rlm@1: return false; rlm@1: } rlm@1: rlm@1: void CKeyInfoCache::Add(CKeyInfo &key) rlm@1: { rlm@1: if (Find(key)) rlm@1: return; rlm@1: if (Keys.Size() >= Size) rlm@1: Keys.DeleteBack(); rlm@1: Keys.Insert(0, key); rlm@1: } rlm@1: rlm@1: static CKeyInfoCache g_GlobalKeyCache(32); rlm@1: static NSynchronization::CCriticalSection g_GlobalKeyCacheCriticalSection; rlm@1: rlm@1: CBase::CBase(): rlm@1: _cachedKeys(16), rlm@1: _ivSize(0) rlm@1: { rlm@1: for (int i = 0; i < sizeof(_iv); i++) rlm@1: _iv[i] = 0; rlm@1: } rlm@1: rlm@1: void CBase::CalculateDigest() rlm@1: { rlm@1: NSynchronization::CCriticalSectionLock lock(g_GlobalKeyCacheCriticalSection); rlm@1: if (_cachedKeys.Find(_key)) rlm@1: g_GlobalKeyCache.Add(_key); rlm@1: else rlm@1: { rlm@1: if (!g_GlobalKeyCache.Find(_key)) rlm@1: { rlm@1: _key.CalculateDigest(); rlm@1: g_GlobalKeyCache.Add(_key); rlm@1: } rlm@1: _cachedKeys.Add(_key); rlm@1: } rlm@1: } rlm@1: rlm@1: #ifndef EXTRACT_ONLY rlm@1: rlm@1: /* rlm@1: STDMETHODIMP CEncoder::ResetSalt() rlm@1: { rlm@1: _key.SaltSize = 4; rlm@1: g_RandomGenerator.Generate(_key.Salt, _key.SaltSize); rlm@1: return S_OK; rlm@1: } rlm@1: */ rlm@1: rlm@1: STDMETHODIMP CEncoder::ResetInitVector() rlm@1: { rlm@1: _ivSize = 8; rlm@1: g_RandomGenerator.Generate(_iv, (unsigned)_ivSize); rlm@1: return S_OK; rlm@1: } rlm@1: rlm@1: STDMETHODIMP CEncoder::WriteCoderProperties(ISequentialOutStream *outStream) rlm@1: { rlm@1: // _key.Init(); rlm@1: for (UInt32 i = _ivSize; i < sizeof(_iv); i++) rlm@1: _iv[i] = 0; rlm@1: rlm@1: UInt32 ivSize = _ivSize; rlm@1: rlm@1: // _key.NumCyclesPower = 0x3F; rlm@1: _key.NumCyclesPower = 19; rlm@1: rlm@1: Byte firstByte = (Byte)(_key.NumCyclesPower | rlm@1: (((_key.SaltSize == 0) ? 0 : 1) << 7) | rlm@1: (((ivSize == 0) ? 0 : 1) << 6)); rlm@1: RINOK(outStream->Write(&firstByte, 1, NULL)); rlm@1: if (_key.SaltSize == 0 && ivSize == 0) rlm@1: return S_OK; rlm@1: Byte saltSizeSpec = (Byte)((_key.SaltSize == 0) ? 0 : (_key.SaltSize - 1)); rlm@1: Byte ivSizeSpec = (Byte)((ivSize == 0) ? 0 : (ivSize - 1)); rlm@1: Byte secondByte = (Byte)(((saltSizeSpec) << 4) | ivSizeSpec); rlm@1: RINOK(outStream->Write(&secondByte, 1, NULL)); rlm@1: if (_key.SaltSize > 0) rlm@1: { rlm@1: RINOK(WriteStream(outStream, _key.Salt, _key.SaltSize)); rlm@1: } rlm@1: if (ivSize > 0) rlm@1: { rlm@1: RINOK(WriteStream(outStream, _iv, ivSize)); rlm@1: } rlm@1: return S_OK; rlm@1: } rlm@1: rlm@1: HRESULT CEncoder::CreateFilter() rlm@1: { rlm@1: _aesFilter = new CAesCbcEncoder; rlm@1: return S_OK; rlm@1: } rlm@1: rlm@1: #endif rlm@1: rlm@1: STDMETHODIMP CDecoder::SetDecoderProperties2(const Byte *data, UInt32 size) rlm@1: { rlm@1: _key.Init(); rlm@1: UInt32 i; rlm@1: for (i = 0; i < sizeof(_iv); i++) rlm@1: _iv[i] = 0; rlm@1: if (size == 0) rlm@1: return S_OK; rlm@1: UInt32 pos = 0; rlm@1: Byte firstByte = data[pos++]; rlm@1: rlm@1: _key.NumCyclesPower = firstByte & 0x3F; rlm@1: if ((firstByte & 0xC0) == 0) rlm@1: return S_OK; rlm@1: _key.SaltSize = (firstByte >> 7) & 1; rlm@1: UInt32 ivSize = (firstByte >> 6) & 1; rlm@1: rlm@1: if (pos >= size) rlm@1: return E_INVALIDARG; rlm@1: Byte secondByte = data[pos++]; rlm@1: rlm@1: _key.SaltSize += (secondByte >> 4); rlm@1: ivSize += (secondByte & 0x0F); rlm@1: rlm@1: if (pos + _key.SaltSize + ivSize > size) rlm@1: return E_INVALIDARG; rlm@1: for (i = 0; i < _key.SaltSize; i++) rlm@1: _key.Salt[i] = data[pos++]; rlm@1: for (i = 0; i < ivSize; i++) rlm@1: _iv[i] = data[pos++]; rlm@1: return S_OK; rlm@1: } rlm@1: rlm@1: STDMETHODIMP CBaseCoder::CryptoSetPassword(const Byte *data, UInt32 size) rlm@1: { rlm@1: _key.Password.SetCapacity((size_t)size); rlm@1: memcpy(_key.Password, data, (size_t)size); rlm@1: return S_OK; rlm@1: } rlm@1: rlm@1: STDMETHODIMP CBaseCoder::Init() rlm@1: { rlm@1: CalculateDigest(); rlm@1: if (_aesFilter == 0) rlm@1: { rlm@1: RINOK(CreateFilter()); rlm@1: } rlm@1: CMyComPtr cp; rlm@1: RINOK(_aesFilter.QueryInterface(IID_ICryptoProperties, &cp)); rlm@1: RINOK(cp->SetKey(_key.Key, sizeof(_key.Key))); rlm@1: RINOK(cp->SetInitVector(_iv, sizeof(_iv))); rlm@1: return S_OK; rlm@1: } rlm@1: rlm@1: STDMETHODIMP_(UInt32) CBaseCoder::Filter(Byte *data, UInt32 size) rlm@1: { rlm@1: return _aesFilter->Filter(data, size); rlm@1: } rlm@1: rlm@1: HRESULT CDecoder::CreateFilter() rlm@1: { rlm@1: _aesFilter = new CAesCbcDecoder; rlm@1: return S_OK; rlm@1: } rlm@1: rlm@1: }}